Skip to content

Commit

Permalink
Allow eta-expansion of methods with dependent types
Browse files Browse the repository at this point in the history
  • Loading branch information
lrytz committed Oct 10, 2022
1 parent 8c6b49b commit cc11e8e
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 22 deletions.
5 changes: 0 additions & 5 deletions src/compiler/scala/tools/nsc/typechecker/ContextErrors.scala
Expand Up @@ -834,11 +834,6 @@ trait ContextErrors extends splain.SplainErrors {
setError(tree)
}

def DependentMethodTpeConversionToFunctionError(tree: Tree, tp: Type): Tree = {
issueNormalTypeError(tree, "method with dependent type " + tp + " cannot be converted to function value")
setError(tree)
}

// cases where we do not necessarily return trees

//checkStarPatOK
Expand Down
Expand Up @@ -24,7 +24,7 @@ import symtab.Flags._
trait EtaExpansion { self: Analyzer =>
import global._

/** Expand partial method application `p.f(es_1)...(es_n)`. Does not support dependent method types (yet).
/** Expand partial method application `p.f(es_1)...(es_n)`.
*
* We expand this to the following block, which evaluates
* the target of the application and its supplied arguments if needed (they are not stable),
Expand Down
9 changes: 3 additions & 6 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Expand Up @@ -3202,12 +3202,9 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
def typedEtaExpansion(tree: Tree, mode: Mode, pt: Type): Tree = {
debuglog(s"eta-expanding $tree: ${tree.tpe} to $pt")

if (tree.tpe.isDependentMethodType) DependentMethodTpeConversionToFunctionError(tree, tree.tpe) // TODO: support this
else {
val expansion = etaExpand(tree, context.owner)
if (context.undetparams.isEmpty) typed(expansion, mode, pt)
else instantiate(typed(expansion, mode), mode, pt)
}
val expansion = etaExpand(tree, context.owner)
if (context.undetparams.isEmpty) typed(expansion, mode, pt)
else instantiate(typed(expansion, mode), mode, pt)
}

def typedRefinement(templ: Template): Unit = {
Expand Down
2 changes: 1 addition & 1 deletion src/reflect/scala/reflect/internal/Definitions.scala
Expand Up @@ -855,7 +855,7 @@ trait Definitions extends api.StandardDefinitions {
| was: $restpe
| now""")(methodToExpressionTp(restpe))
case mt @ MethodType(_, restpe) if mt.isImplicit => methodToExpressionTp(restpe)
case mt @ MethodType(_, restpe) if !mt.isDependentMethodType =>
case mt @ MethodType(_, restpe) =>
if (phase.erasedTypes) FunctionClass(mt.params.length).tpe
else functionType(mt.paramTypes, methodToExpressionTp(restpe))
case NullaryMethodType(restpe) => methodToExpressionTp(restpe)
Expand Down

This file was deleted.

This file was deleted.

48 changes: 48 additions & 0 deletions test/files/run/eta-dependent.check
@@ -0,0 +1,48 @@

scala> object defs {
val a = "obj"
def aa: a.type = a
def s = this
def f(x: Int): a.type = a
def g(x: Int)(y: x.type) = 0
def h(x: a.type)(y: a.type) = 0
}
object defs

scala> import defs._
import defs._

scala> val f1 = f _
val f1: Int => defs.a.type = <function>

scala> val f2: Int => a.type = f
val f2: Int => defs.a.type = <function>

scala> val f3: Int => Object = f
val f3: Int => Object = <function>

scala> val g1 = g(10) _
val g1: Int(10) => Int = <function>

scala> val g2: 10 => Int = g1
val g2: 10 => Int = <function>

scala> val g3: 11 => Int = g(11)
val g3: 11 => Int = <function>

scala> val g4: Int => Int = g(11) // mismatch
^
error: type mismatch;
found : Int(11) => Int
required: Int => Int

scala> val h1 = s.h(aa) _
val h1: defs.a.type => Int = <function>

scala> val h2: a.type => Int = h1
val h2: defs.a.type => Int = <function>

scala> val h3: a.type => Int = s.h(aa)
val h3: defs.a.type => Int = <function>

scala> :quit
72 changes: 72 additions & 0 deletions test/files/run/eta-dependent.scala
@@ -0,0 +1,72 @@
object NoMoreNeg {
def foo(x: AnyRef): x.type = x
val x: AnyRef => Any = foo
}

object t12641 {
def f(sb: StringBuilder) = Option("").foreach(sb.append)
}

object t12641a {
trait A {
def foo(s: String): this.type
def foo(s: Int): this.type
}
trait T {
val a1: A
val o: Option[String]

def t(a2: A): Unit = {
o.foreach(a1.foo)
o.foreach(a2.foo)

val f2: String => a2.type = a2.foo
val f3: String => A = a2.foo
}
}
}

object t12641b {
trait A {
def foo(s: String): this.type
}
trait T {
val a1: A
val o: Option[String]

def t(a2: A): Unit = {
o.foreach(a1.foo)
o.foreach(a2.foo)

val f1 = a2.foo _
val f2: String => a2.type = a2.foo
val f3: String => A = a2.foo
}
}
}

import scala.tools.partest._

object Test extends ReplTest with Lambdaless {
def code = """
object defs {
val a = "obj"
def aa: a.type = a
def s = this
def f(x: Int): a.type = a
def g(x: Int)(y: x.type) = 0
def h(x: a.type)(y: a.type) = 0
}
import defs._
val f1 = f _
val f2: Int => a.type = f
val f3: Int => Object = f
val g1 = g(10) _
val g2: 10 => Int = g1
val g3: 11 => Int = g(11)
val g4: Int => Int = g(11) // mismatch
val h1 = s.h(aa) _
val h2: a.type => Int = h1
val h3: a.type => Int = s.h(aa)
""".trim
}

0 comments on commit cc11e8e

Please sign in to comment.