Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow eta-expansion of methods with dependent types #10166

Merged
merged 1 commit into from Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/compiler/scala/tools/nsc/typechecker/ContextErrors.scala
Expand Up @@ -834,11 +834,6 @@ trait ContextErrors extends splain.SplainErrors {
setError(tree)
}

def DependentMethodTpeConversionToFunctionError(tree: Tree, tp: Type): Tree = {
issueNormalTypeError(tree, "method with dependent type " + tp + " cannot be converted to function value")
setError(tree)
}

// cases where we do not necessarily return trees

//checkStarPatOK
Expand Down
Expand Up @@ -24,7 +24,7 @@ import symtab.Flags._
trait EtaExpansion { self: Analyzer =>
import global._

/** Expand partial method application `p.f(es_1)...(es_n)`. Does not support dependent method types (yet).
/** Expand partial method application `p.f(es_1)...(es_n)`.
*
* We expand this to the following block, which evaluates
* the target of the application and its supplied arguments if needed (they are not stable),
Expand Down
9 changes: 3 additions & 6 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Expand Up @@ -3202,12 +3202,9 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
def typedEtaExpansion(tree: Tree, mode: Mode, pt: Type): Tree = {
debuglog(s"eta-expanding $tree: ${tree.tpe} to $pt")

if (tree.tpe.isDependentMethodType) DependentMethodTpeConversionToFunctionError(tree, tree.tpe) // TODO: support this
else {
val expansion = etaExpand(tree, context.owner)
if (context.undetparams.isEmpty) typed(expansion, mode, pt)
else instantiate(typed(expansion, mode), mode, pt)
}
val expansion = etaExpand(tree, context.owner)
if (context.undetparams.isEmpty) typed(expansion, mode, pt)
else instantiate(typed(expansion, mode), mode, pt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the restriction yet, but etaExpand still has adriaanm comment that not supported.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I'll be off for a week 🇷🇴 and don't have more time right now to try to break it, will get back to it after.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enjoy! Insert joke about children as dependent types.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a few more tests, AFAICT etaExpand works well with dependent method types

}

def typedRefinement(templ: Template): Unit = {
Expand Down
2 changes: 1 addition & 1 deletion src/reflect/scala/reflect/internal/Definitions.scala
Expand Up @@ -855,7 +855,7 @@ trait Definitions extends api.StandardDefinitions {
| was: $restpe
| now""")(methodToExpressionTp(restpe))
case mt @ MethodType(_, restpe) if mt.isImplicit => methodToExpressionTp(restpe)
case mt @ MethodType(_, restpe) if !mt.isDependentMethodType =>
case mt @ MethodType(_, restpe) =>
if (phase.erasedTypes) FunctionClass(mt.params.length).tpe
else functionType(mt.paramTypes, methodToExpressionTp(restpe))
case NullaryMethodType(restpe) => methodToExpressionTp(restpe)
Expand Down

This file was deleted.

This file was deleted.

48 changes: 48 additions & 0 deletions test/files/run/eta-dependent.check
@@ -0,0 +1,48 @@

scala> object defs {
val a = "obj"
def aa: a.type = a
def s = this
def f(x: Int): a.type = a
def g(x: Int)(y: x.type) = 0
def h(x: a.type)(y: a.type) = 0
}
object defs

scala> import defs._
import defs._

scala> val f1 = f _
val f1: Int => defs.a.type = <function>

scala> val f2: Int => a.type = f
val f2: Int => defs.a.type = <function>

scala> val f3: Int => Object = f
val f3: Int => Object = <function>

scala> val g1 = g(10) _
val g1: Int(10) => Int = <function>

scala> val g2: 10 => Int = g1
val g2: 10 => Int = <function>

scala> val g3: 11 => Int = g(11)
val g3: 11 => Int = <function>

scala> val g4: Int => Int = g(11) // mismatch
^
error: type mismatch;
found : Int(11) => Int
required: Int => Int

scala> val h1 = s.h(aa) _
val h1: defs.a.type => Int = <function>

scala> val h2: a.type => Int = h1
val h2: defs.a.type => Int = <function>

scala> val h3: a.type => Int = s.h(aa)
val h3: defs.a.type => Int = <function>

scala> :quit
72 changes: 72 additions & 0 deletions test/files/run/eta-dependent.scala
@@ -0,0 +1,72 @@
object NoMoreNeg {
def foo(x: AnyRef): x.type = x
val x: AnyRef => Any = foo
}

object t12641 {
def f(sb: StringBuilder) = Option("").foreach(sb.append)
}

object t12641a {
trait A {
def foo(s: String): this.type
def foo(s: Int): this.type
}
trait T {
val a1: A
val o: Option[String]

def t(a2: A): Unit = {
o.foreach(a1.foo)
o.foreach(a2.foo)

val f2: String => a2.type = a2.foo
val f3: String => A = a2.foo
}
}
}

object t12641b {
trait A {
def foo(s: String): this.type
}
trait T {
val a1: A
val o: Option[String]

def t(a2: A): Unit = {
o.foreach(a1.foo)
o.foreach(a2.foo)

val f1 = a2.foo _
val f2: String => a2.type = a2.foo
val f3: String => A = a2.foo
}
}
}

import scala.tools.partest._

object Test extends ReplTest with Lambdaless {
def code = """
object defs {
val a = "obj"
def aa: a.type = a
def s = this
def f(x: Int): a.type = a
def g(x: Int)(y: x.type) = 0
def h(x: a.type)(y: a.type) = 0
}
import defs._
val f1 = f _
val f2: Int => a.type = f
val f3: Int => Object = f
val g1 = g(10) _
val g2: 10 => Int = g1
val g3: 11 => Int = g(11)
val g4: Int => Int = g(11) // mismatch
val h1 = s.h(aa) _
val h2: a.type => Int = h1
val h3: a.type => Int = s.h(aa)
""".trim
}