Skip to content

Commit

Permalink
Merge pull request #8545 from joroKr21/variance-holes
Browse files Browse the repository at this point in the history
Plug many variance holes (pos and neg)
  • Loading branch information
lrytz committed Mar 19, 2020
2 parents bd1ee19 + a29da6a commit ae10be4
Show file tree
Hide file tree
Showing 11 changed files with 257 additions and 63 deletions.
12 changes: 4 additions & 8 deletions src/compiler/scala/tools/nsc/typechecker/RefChecks.scala
Expand Up @@ -890,9 +890,9 @@ abstract class RefChecks extends Transform {
case ClassInfoType(parents, _, clazz) => "supertype "+intersectionType(parents, clazz.owner)
case _ => "type "+tp
}
override def issueVarianceError(base: Symbol, sym: Symbol, required: Variance): Unit = {
override def issueVarianceError(base: Symbol, sym: Symbol, required: Variance, tpe: Type): Unit = {
reporter.error(base.pos,
s"${sym.variance} $sym occurs in $required position in ${tpString(base.info)} of $base")
s"${sym.variance} $sym occurs in $required position in ${tpString(tpe)} of $base")
}
}

Expand Down Expand Up @@ -1563,8 +1563,6 @@ abstract class RefChecks extends Transform {

if (!sym.exists)
devWarning("Select node has NoSymbol! " + tree + " / " + tree.tpe)
else if (sym.isLocalToThis)
varianceValidator.checkForEscape(sym, currentClass)

def checkSuper(mix: Name) =
// term should have been eliminated by super accessors
Expand Down Expand Up @@ -1771,13 +1769,11 @@ abstract class RefChecks extends Transform {
result.transform(this)
}
result1 match {
case ClassDef(_, _, _, _)
| TypeDef(_, _, _, _)
| ModuleDef(_, _, _) =>
case ClassDef(_, _, _, _) | TypeDef(_, _, _, _) | ModuleDef(_, _, _) =>
if (result1.symbol.isLocalToBlock || result1.symbol.isTopLevel)
varianceValidator.traverse(result1)
case tt @ TypeTree() if tt.original != null =>
varianceValidator.traverse(tt.original) // See scala/bug#7872
varianceValidator.validateVarianceOfPolyTypesIn(tt.tpe)
case _ =>
}

Expand Down
160 changes: 112 additions & 48 deletions src/reflect/scala/reflect/internal/Variances.scala
Expand Up @@ -28,7 +28,6 @@ trait Variances {
* TODO - eliminate duplication with varianceInType
*/
class VarianceValidator extends InternalTraverser {
private[this] val escapedLocals = mutable.HashSet[Symbol]()
// A flag for when we're in a refinement, meaning method parameter types
// need to be checked.
private[this] var inRefinement = false
Expand All @@ -38,34 +37,26 @@ trait Variances {
try body finally inRefinement = saved
}

/** Is every symbol in the owner chain between `site` and the owner of `sym`
* either a term symbol or private[this]? If not, add `sym` to the set of
* escaped locals.
* @pre sym.isLocalToThis
*/
@tailrec final def checkForEscape(sym: Symbol, site: Symbol): Unit = {
if (site == sym.owner || site == sym.owner.moduleClass || site.hasPackageFlag) () // done
else if (site.isTerm || site.isPrivateLocal) checkForEscape(sym, site.owner) // ok - recurse to owner
else escapedLocals += sym
}

protected def issueVarianceError(base: Symbol, sym: Symbol, required: Variance): Unit = ()
protected def issueVarianceError(base: Symbol, sym: Symbol, required: Variance, tpe: Type): Unit = ()

// Flip occurrences of type parameters and parameters, unless
// - it's a constructor, or case class factory or extractor
// - it's a type parameter / parameter of a local definition
// - it's a type parameter of tvar's owner.
def shouldFlip(sym: Symbol, tvar: Symbol) = (
sym.isParameter
&& !sym.owner.isLocalToThis
&& !(tvar.isTypeParameterOrSkolem && sym.isTypeParameterOrSkolem && tvar.owner == sym.owner)
)

// Is `sym` is local to a term or is private[this] or protected[this]?
def isExemptFromVariance(sym: Symbol): Boolean = !sym.owner.isClass || (
(sym.isLocalToThis || sym.isSuperAccessor) // super accessors are implicitly local #4345
&& !escapedLocals(sym)
)
def isExemptFromVariance(sym: Symbol): Boolean =
// super accessors are implicitly local #4345
!sym.owner.isClass || sym.isLocalToThis || sym.isSuperAccessor

private object ValidateVarianceMap extends VariancedTypeMap {
private[this] var base: Symbol = _
private[this] var inLowerBoundOf: Symbol = _

/** The variance of a symbol occurrence of `tvar` seen at the level of the definition of `base`.
* The search proceeds from `base` to the owner of `tvar`.
Expand All @@ -85,12 +76,16 @@ trait Variances {
else v

@tailrec
def loop(sym: Symbol, v: Variance): Variance = (
if (sym == tvar.owner || v.isBivariant) v
def loop(sym: Symbol, v: Variance): Variance =
if (v.isBivariant) v
else if (sym == tvar.owner)
// We can't move this to `shouldFlip`, because it's needed only once at the end.
if (inLowerBoundOf == sym) v.flip else v
else loop(sym.owner, nextVariance(sym, v))
)

loop(base, Covariant)
}

def isUncheckedVariance(tp: Type) = tp match {
case AnnotatedType(annots, _) => annots exists (_ matches definitions.uncheckedVarianceClass)
case _ => false
Expand All @@ -104,17 +99,19 @@ trait Variances {
def base_s = s"$base in ${base.owner}" + (if (base.owner.isClass) "" else " in " + base.owner.enclClass)
log(s"verifying $sym_s is $required at $base_s")
if (sym.variance != required)
issueVarianceError(base, sym, required)
issueVarianceError(base, sym, required, base.info)
}
}

override def mapOver(decls: Scope): Scope = {
decls foreach (sym => withVariance(if (sym.isAliasType) Invariant else variance)(this(sym.info)))
decls
}

private def resultTypeOnly(tp: Type) = tp match {
case mt: MethodType => !inRefinement
case pt: PolyType => true
case _ => false
case _: MethodType => !inRefinement
case _: PolyType => true
case _ => false
}

/** For PolyTypes, type parameters are skipped because they are defined
Expand All @@ -125,75 +122,141 @@ trait Variances {
def apply(tp: Type): Type = {
tp match {
case _ if isUncheckedVariance(tp) =>
case _ if resultTypeOnly(tp) => this(tp.resultType)
case TypeRef(_, sym, _) if shouldDealias(sym) => this(tp.normalize)
case TypeRef(_, sym, _) if !sym.variance.isInvariant => checkVarianceOfSymbol(sym) ; tp.mapOver(this)
case _ if resultTypeOnly(tp) => apply(tp.resultType)
case TypeRef(_, sym, _) if shouldDealias(sym) => apply(tp.normalize)
case TypeRef(_, sym, _) if !sym.variance.isInvariant => checkVarianceOfSymbol(sym); tp.mapOver(this)
case RefinedType(_, _) => withinRefinement(tp.mapOver(this))
case ClassInfoType(parents, _, _) => parents foreach this
case mt @ MethodType(_, result) => flipped(mt.paramTypes foreach this) ; this(result)
case ClassInfoType(parents, _, _) => parents.foreach(apply)
case mt @ MethodType(_, result) => flipped(mt.paramTypes.foreach(apply)); apply(result)
case _ => tp.mapOver(this)
}
// We're using TypeMap here for type traversal only. To avoid wasteful symbol
// cloning during the recursion, it is important to return the input `tp`, rather
// than the result of the pattern match above, which normalizes types.
tp
}

private def shouldDealias(sym: Symbol): Boolean = {
// The RHS of (private|protected)[this] type aliases are excluded from variance checks. This is
// implemented in relativeVariance.
// As such, we need to expand references to them to retain soundness. Example: neg/t8079a.scala
sym.isAliasType && isExemptFromVariance(sym)
}

/** Validate the variance of types in the definition of `base`.
*
* Traverse the type signature of `base` and for each type parameter:
* - Calculate the relative variance between `base` and the type parameter's owner by
* walking the owner chain of `base`.
* - Calculate the required variance of the type parameter which is the product of the
* relative variance and the current variance in the type signature of `base`.
* - Ensure that the declared variance of the type parameter is compatible with the
* required variance, otherwise issue an error.
*
* Lower bounds need special handling. By default the variance is flipped when entering a
* lower bound. In most cases this is the correct behaviour except for the type parameters
* of higher-kinded types. E.g. in `Foo` below `x` occurs in covariant position:
* `class Foo[F[+_]] { type G[+x] >: F[x] }`
*
* To handle this special case, track when entering the lower bound of a HKT in a variable
* and flip the relative variance for its type parameters. (flipping the variance a second
* time negates the first flip).
*/
def validateDefinition(base: Symbol): Unit = {
val saved = this.base
this.base = base
try apply(base.info)
finally this.base = saved
base.info match {
case PolyType(_, TypeBounds(lo, hi)) =>
inLowerBoundOf = base
try flipped(apply(lo))
finally inLowerBoundOf = null
apply(hi)
case other =>
apply(other)
}
}
}

/** Validate variance of info of symbol `base` */
private def validateVariance(base: Symbol): Unit = {
ValidateVarianceMap validateDefinition base
private object PolyTypeVarianceMap extends TypeMap {

private def ownerOf(pt: PolyType): Symbol =
pt.typeParams.head.owner

private def checkPolyTypeParam(pt: PolyType, tparam: Symbol, tpe: Type): Unit =
if (!tparam.isInvariant) {
val required = varianceInType(tpe)(tparam)
if (!required.isBivariant && tparam.variance != required)
issueVarianceError(ownerOf(pt), tparam, required, pt)
}

def apply(tp: Type): Type = {
tp match {
case pt @ PolyType(typeParams, TypeBounds(lo, hi)) =>
typeParams.foreach { tparam =>
checkPolyTypeParam(pt, tparam, lo)
checkPolyTypeParam(pt, tparam, hi)
}

pt.mapOver(this)

case pt @ PolyType(typeParams, resultType) =>
typeParams.foreach(checkPolyTypeParam(pt, _, resultType))
pt.mapOver(this)

case _ =>
tp.mapOver(this)
}

tp
}
}

/** Validate the variance of (the type parameters of) PolyTypes in `tpe`.
*
* `validateDefinition` cannot handle PolyTypes in arbitrary position, because in general
* the relative variance of such types cannot be computed by walking the owner chain.
*
* Instead this method applies a naive algorithm which is correct but less efficient:
* use `varianceInType` to check each type parameter of a PolyType separately.
*/
def validateVarianceOfPolyTypesIn(tpe: Type): Unit =
PolyTypeVarianceMap(tpe)

override def traverse(tree: Tree): Unit = {
def sym = tree.symbol
// No variance check for object-private/protected methods/values.
// Or constructors, or case class factory or extractor.
def skip = (
sym == NoSymbol
|| sym.isLocalToThis
|| sym.owner.isConstructor
|| sym.owner.isCaseApplyOrUnapply
|| sym.owner.isConstructor // FIXME: this is unsound - scala/bug#8737
|| sym.owner.isCaseApplyOrUnapply // same treatment as constructors
|| sym.isParamAccessor && sym.isLocalToThis // local class parameters are construction only
)

tree match {
case defn: MemberDef if skip =>
case _: MemberDef if skip =>
debuglog(s"Skipping variance check of ${sym.defString}")
case ClassDef(_, _, _, _) | TypeDef(_, _, _, _) =>
validateVariance(sym)
ValidateVarianceMap.validateDefinition(sym)
tree.traverse(this)
case ModuleDef(_, _, _) =>
validateVariance(sym.moduleClass)
ValidateVarianceMap.validateDefinition(sym.moduleClass)
tree.traverse(this)
case ValDef(_, _, _, _) =>
validateVariance(sym)
ValidateVarianceMap.validateDefinition(sym)
case DefDef(_, _, tparams, vparamss, _, _) =>
validateVariance(sym)
ValidateVarianceMap.validateDefinition(sym)
traverseTrees(tparams)
traverseTreess(vparamss)
case Template(_, _, _) =>
tree.traverse(this)
case CompoundTypeTree(templ) =>
case CompoundTypeTree(_) =>
tree.traverse(this)

// scala/bug#7872 These two cases make sure we don't miss variance exploits
// in originals, e.g. in `foo[({type l[+a] = List[a]})#l]`
case tt @ TypeTree() if tt.original != null =>
tt.original.traverse(this)
case tt : TypTree =>
case tt: TypTree =>
tt.traverse(this)

case _ =>
}
}
Expand Down Expand Up @@ -230,7 +293,7 @@ trait Variances {
case ThisType(_) | ConstantType(_) => Bivariant
case TypeRef(_, tparam, _) if tparam eq this.tparam => Covariant
case NullaryMethodType(restpe) => inType(restpe)
case SingleType(pre, sym) => inType(pre)
case SingleType(pre, _) => inType(pre)
case TypeRef(pre, _, _) if tp.isHigherKinded => inType(pre) // a type constructor cannot occur in tp's args
case TypeRef(pre, sym, args) => inType(pre) & inArgs(sym, args)
case TypeBounds(lo, hi) => inType(lo).flip & inType(hi)
Expand All @@ -239,6 +302,7 @@ trait Variances {
case PolyType(tparams, restpe) => inSyms(tparams).flip & inType(restpe)
case ExistentialType(tparams, restpe) => inSyms(tparams) & inType(restpe)
case AnnotatedType(annots, tp) => inAnnots(annots) & inType(tp)
case SuperType(thistpe, supertpe) => inType(thistpe) & inType(supertpe)
}

def apply(tp: Type, tparam: Symbol): Variance = {
Expand Down
2 changes: 1 addition & 1 deletion test/files/neg/t7872.check
@@ -1,7 +1,7 @@
t7872.scala:6: error: contravariant type a occurs in covariant position in type [-a]Cov[a] of type l
type x = {type l[-a] = Cov[a]}
^
t7872.scala:8: error: covariant type a occurs in contravariant position in type [+a]Inv[a] of type l
t7872.scala:8: error: covariant type a occurs in contravariant position in type [+a]Inv[a] of value <local l>
foo[({type l[+a] = Inv[a]})#l]
^
t7872.scala:5: error: contravariant type a occurs in covariant position in type [-a]Cov[a] of type l
Expand Down
4 changes: 2 additions & 2 deletions test/files/neg/t7872b.check
@@ -1,7 +1,7 @@
t7872b.scala:8: error: contravariant type a occurs in covariant position in type [-a]List[a] of type l
t7872b.scala:8: error: contravariant type a occurs in covariant position in type [-a]List[a] of value <local l>
def oops1 = down[({type l[-a] = List[a]})#l](List('whatever: Object)).head + "oops"
^
t7872b.scala:19: error: covariant type a occurs in contravariant position in type [+a]coinv.Stringer[a] of type l
t7872b.scala:19: error: covariant type a occurs in contravariant position in type [+a]a => String of value <local l>
def oops2 = up[({type l[+a] = Stringer[a]})#l]("printed: " + _)
^
2 errors
4 changes: 4 additions & 0 deletions test/files/neg/t9911.check
@@ -0,0 +1,4 @@
t9911.scala:23: error: super may not be used on value source
super.source.getSomething
^
1 error
28 changes: 28 additions & 0 deletions test/files/neg/t9911.scala
@@ -0,0 +1,28 @@
// This should say:
// Error: super may not be used on value source
class ScalacBug {

class SomeClass {

type U

// Changing T or U stops the problem
def getSomething[T]: U = ???
}

trait Base {

// Changing this to a def like it should be stops the problem
val source: SomeClass = ???
}

class Bug extends Base {

override val source = {
// Not calling the function stops the problem
super.source.getSomething
???
}
}

}
22 changes: 22 additions & 0 deletions test/files/neg/variance-holes.check
@@ -0,0 +1,22 @@
variance-holes.scala:9: error: covariant type x occurs in contravariant position in type [+x, +y] >: F[x,y] of type F2
def asWiden[F2[+x, +y] >: F[x, y]]: F2[Int, Int] = v
^
variance-holes.scala:2: error: contravariant type A occurs in covariant position in type [-A] >: List[A] of type Lower1
type Lower1[-A] >: List[A]
^
variance-holes.scala:5: error: covariant type x occurs in contravariant position in type [+x] >: F[x] of type G
type G[+x] >: F[x]
^
variance-holes.scala:13: error: covariant type A occurs in contravariant position in type AnyRef{type T >: A} of method foo
def foo: { type T >: A }
^
variance-holes.scala:17: error: covariant type A occurs in contravariant position in type AnyRef{type T <: A} of value x
def foo(x: { type T <: A }): Unit
^
variance-holes.scala:20: error: covariant type A occurs in contravariant position in type <: AnyRef{type T >: A} of type x
class RefinedLower[+A, x <: { type T >: A }]
^
variance-holes.scala:21: error: covariant type A occurs in contravariant position in type A of value x_=
private[this] class PrivateThis[+A](var x: A)
^
7 errors

0 comments on commit ae10be4

Please sign in to comment.