Skip to content

Commit

Permalink
Fix improper usage of constrained breaking type inference
Browse files Browse the repository at this point in the history
In multiple places, we had code equivalent to the following pattern:

    val (tl2, targs) = constrained(tl)
    tl2.resultType <:< ...

which lead to subtype checks directly involving the TypeParamRefs of the
constrained type lambda. This commit uses the following pattern instead:

    val (tl2, targs) = constrained(tl)
    tl2.instantiate(targs.map(_.tpe)) <:< ...

which substitutes the TypeParamRefs by the corresponding TypeVars in the
constraint. This is necessary because when comparing
TypeParamRefs in isSubType:
- we only recurse on the bounds of the TypeParamRef using
  `isSubTypeWhenFrozen` which prevents further constraints from being
  added (see the added stm.scala test case for an example where this
  matters).
- if the corresponding TypeVar is instantiated and the TyperState has
  been gc()'ed, there is no way to find the instantiation corresponding
  to the current TypeParamRef anymore.

There is one place where I left the old logic intact:
`TrackingTypeComparer#matchCase` because the match type caching
logic (in `MatchType#reduced`) conflicted with the use of TypeVars since
it retracts the current TyperState.

This change breaks a test which involves an unlikely combination of
implicit conversion, overloading and apply insertion. Given that there
is always a tension between type inference and implicit conversion, and
that we're discouraging uses of implicit conversions, I think that's an
acceptable trade-off.
  • Loading branch information
smarter committed Oct 20, 2021
1 parent faf9538 commit 0182e06
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 13 deletions.
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Expand Up @@ -411,7 +411,7 @@ trait Applications extends Compatibility {
*/
@threadUnsafe lazy val methType: Type = liftedFunType.widen match {
case funType: MethodType => funType
case funType: PolyType => constrained(funType).resultType
case funType: PolyType => instantiateWithTypeVars(funType)
case tp => tp //was: funType
}

Expand Down Expand Up @@ -1571,7 +1571,7 @@ trait Applications extends Compatibility {
case tp2: MethodType => true // (3a)
case tp2: PolyType if tp2.resultType.isInstanceOf[MethodType] => true // (3a)
case tp2: PolyType => // (3b)
explore(isAsSpecificValueType(tp1, constrained(tp2).resultType))
explore(isAsSpecificValueType(tp1, instantiateWithTypeVars(tp2)))
case _ => // 3b)
isAsSpecificValueType(tp1, tp2)
}
Expand Down Expand Up @@ -1738,7 +1738,7 @@ trait Applications extends Compatibility {
resultType.revealIgnored match {
case resultType: ValueType =>
altType.widen match {
case tp: PolyType => resultConforms(altSym, constrained(tp).resultType, resultType)
case tp: PolyType => resultConforms(altSym, instantiateWithTypeVars(tp), resultType)
case tp: MethodType => constrainResult(altSym, tp.resultType, resultType)
case _ => true
}
Expand Down
9 changes: 7 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Expand Up @@ -659,6 +659,11 @@ object ProtoTypes {
def constrained(tl: TypeLambda)(using Context): TypeLambda =
constrained(tl, EmptyTree)._1

/** Instantiate `tl` with fresh type variables added to the constraint. */
def instantiateWithTypeVars(tl: TypeLambda)(using Context): Type =
val targs = constrained(tl, ast.tpd.EmptyTree, alwaysAddTypeVars = true)._2
tl.instantiate(targs.tpes)

/** A new type variable with given bounds for its origin.
* @param represents If exists, the TermParamRef that the TypeVar represents
* in the substitution generated by `resultTypeApprox`
Expand Down Expand Up @@ -707,7 +712,7 @@ object ProtoTypes {
else mt.resultType

/** The normalized form of a type
* - unwraps polymorphic types, tracking their parameters in the current constraint
* - instantiate polymorphic types with fresh type variables in the current constraint
* - skips implicit parameters of methods and functions;
* if result type depends on implicit parameter, replace with wildcard.
* - converts non-dependent method types to the corresponding function types
Expand All @@ -726,7 +731,7 @@ object ProtoTypes {
Stats.record("normalize")
tp.widenSingleton match {
case poly: PolyType =>
normalize(constrained(poly).resultType, pt)
normalize(instantiateWithTypeVars(poly), pt)
case mt: MethodType =>
if (mt.isImplicitMethod) normalize(resultTypeApprox(mt, wildcardOnly = true), pt)
else if (mt.isResultDependent) tp
Expand Down
9 changes: 5 additions & 4 deletions compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala
Expand Up @@ -7,6 +7,7 @@ import dotty.tools.dotc.core.Contexts.{*, given}
import dotty.tools.dotc.core.Decorators.{*, given}
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.ast.tpd.*
import dotty.tools.dotc.typer.ProtoTypes.constrained

import org.junit.Test
Expand All @@ -18,8 +19,8 @@ class ConstraintsTest:
@Test def mergeParamsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T, R]: Any }") {
val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
val List(s, t, r) = tp.paramRefs
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t, r) = tvars.tpes

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
Expand All @@ -37,8 +38,8 @@ class ConstraintsTest:
@Test def mergeBoundsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T]: Any }") {
val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
val List(s, t) = tp.paramRefs
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
Expand Down
10 changes: 10 additions & 0 deletions tests/pos/stm.scala
@@ -0,0 +1,10 @@
class Inv[X]
class Ref[X]
object Ref {
def apply(i: Inv[Int], x: Int): Ref[Int] = ???
def apply[Y](i: Inv[Y], x: Y): Ref[Y] = ???
}

class A {
val ref: Ref[List[AnyRef]] = Ref(new Inv[List[AnyRef]], List.empty)
}
3 changes: 1 addition & 2 deletions tests/pos/t0851.scala
@@ -1,9 +1,8 @@
package test

object test1 {
case class Foo[T,T2](f : (T,T2) => String) extends (((T,T2)) => String){
case class Foo[T,T2](f : (T,T2) => String) {
def apply(t : T) = (s:T2) => f(t,s)
def apply(p : (T,T2)) = f(p._1,p._2)
}
implicit def g[T](f : (T,String) => String): Foo[T, String] = Foo(f)
def main(args : Array[String]) : Unit = {
Expand Down
3 changes: 1 addition & 2 deletions tests/pos/t2913.scala
Expand Up @@ -33,9 +33,8 @@ object TestNoAutoTupling {

// t0851 is essentially the same:
object test1 {
case class Foo[T,T2](f : (T,T2) => String) extends (((T,T2)) => String){
case class Foo[T,T2](f : (T,T2) => String) {
def apply(t : T) = (s:T2) => f(t,s)
def apply(p : (T,T2)) = f(p._1,p._2)
}
implicit def g[T](f : (T,String) => String): test1.Foo[T,String] = Foo(f)
def main(args : Array[String]) : Unit = {
Expand Down

0 comments on commit 0182e06

Please sign in to comment.