Skip to content

Commit

Permalink
Support records in JavaParsers (#16762)
Browse files Browse the repository at this point in the history
This is a port of scala/scala#9551.

Fixes #14846.
  • Loading branch information
TheElectronWill committed May 21, 2023
2 parents b67e269 + da4996a commit 4f3632f
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 20 deletions.
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Expand Up @@ -688,6 +688,7 @@ class Definitions {
@tu lazy val JavaCalendarClass: ClassSymbol = requiredClass("java.util.Calendar")
@tu lazy val JavaDateClass: ClassSymbol = requiredClass("java.util.Date")
@tu lazy val JavaFormattableClass: ClassSymbol = requiredClass("java.util.Formattable")
@tu lazy val JavaRecordClass: Symbol = getClassIfDefined("java.lang.Record")

@tu lazy val JavaEnumClass: ClassSymbol = {
val cls = requiredClass("java.lang.Enum")
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Expand Up @@ -204,6 +204,7 @@ object StdNames {
final val Null: N = "Null"
final val Object: N = "Object"
final val FromJavaObject: N = "<FromJavaObject>"
final val Record: N = "Record"
final val Product: N = "Product"
final val PartialFunction: N = "PartialFunction"
final val PrefixType: N = "PrefixType"
Expand Down Expand Up @@ -913,6 +914,10 @@ object StdNames {
final val VOLATILEkw: N = kw("volatile")
final val WHILEkw: N = kw("while")

final val RECORDid: N = "record"
final val VARid: N = "var"
final val YIELDid: N = "yield"

final val BoxedBoolean: N = "java.lang.Boolean"
final val BoxedByte: N = "java.lang.Byte"
final val BoxedCharacter: N = "java.lang.Character"
Expand Down Expand Up @@ -945,6 +950,8 @@ object StdNames {
final val JavaSerializable: N = "java.io.Serializable"
}



class JavaTermNames extends JavaNames[TermName] {
protected def fromString(s: String): TermName = termName(s)
}
Expand Down
85 changes: 77 additions & 8 deletions compiler/src/dotty/tools/dotc/parsing/JavaParsers.scala
Expand Up @@ -20,7 +20,8 @@ import StdNames._
import reporting._
import dotty.tools.dotc.util.SourceFile
import util.Spans._
import scala.collection.mutable.ListBuffer

import scala.collection.mutable.{ListBuffer, LinkedHashMap}

object JavaParsers {

Expand Down Expand Up @@ -96,8 +97,12 @@ object JavaParsers {
def javaLangDot(name: Name): Tree =
Select(javaDot(nme.lang), name)

/** Tree representing `java.lang.Object` */
def javaLangObject(): Tree = javaLangDot(tpnme.Object)

/** Tree representing `java.lang.Record` */
def javaLangRecord(): Tree = javaLangDot(tpnme.Record)

def arrayOf(tpt: Tree): AppliedTypeTree =
AppliedTypeTree(scalaDot(tpnme.Array), List(tpt))

Expand Down Expand Up @@ -555,6 +560,14 @@ object JavaParsers {

def definesInterface(token: Int): Boolean = token == INTERFACE || token == AT

/** If the next token is the identifier "record", convert it into the RECORD token.
* This makes it easier to handle records in various parts of the code,
* in particular when a `parentToken` is passed to some functions.
*/
def adaptRecordIdentifier(): Unit =
if in.token == IDENTIFIER && in.name == jnme.RECORDid then
in.token = RECORD

def termDecl(start: Offset, mods: Modifiers, parentToken: Int, parentTParams: List[TypeDef]): List[Tree] = {
val inInterface = definesInterface(parentToken)
val tparams = if (in.token == LT) typeParams(Flags.JavaDefined | Flags.Param) else List()
Expand All @@ -581,6 +594,16 @@ object JavaParsers {
TypeTree(), methodBody()).withMods(mods)
}
}
} else if (in.token == LBRACE && rtptName != nme.EMPTY && parentToken == RECORD) {
/*
record RecordName(T param1, ...) {
RecordName { // <- here
// methodBody
}
}
*/
methodBody()
Nil
}
else {
var mods1 = mods
Expand Down Expand Up @@ -717,12 +740,11 @@ object JavaParsers {
ValDef(name, tpt2, if (mods.is(Flags.Param)) EmptyTree else unimplementedExpr).withMods(mods1)
}

def memberDecl(start: Offset, mods: Modifiers, parentToken: Int, parentTParams: List[TypeDef]): List[Tree] = in.token match {
case CLASS | ENUM | INTERFACE | AT =>
typeDecl(start, if (definesInterface(parentToken)) mods | Flags.JavaStatic else mods)
def memberDecl(start: Offset, mods: Modifiers, parentToken: Int, parentTParams: List[TypeDef]): List[Tree] = in.token match
case CLASS | ENUM | RECORD | INTERFACE | AT =>
typeDecl(start, if definesInterface(parentToken) then mods | Flags.JavaStatic else mods)
case _ =>
termDecl(start, mods, parentToken, parentTParams)
}

def makeCompanionObject(cdef: TypeDef, statics: List[Tree]): Tree =
atSpan(cdef.span) {
Expand Down Expand Up @@ -804,6 +826,51 @@ object JavaParsers {
addCompanionObject(statics, cls)
}

def recordDecl(start: Offset, mods: Modifiers): List[Tree] =
accept(RECORD)
val nameOffset = in.offset
val name = identForType()
val tparams = typeParams()
val header = formalParams()
val superclass = javaLangRecord() // records always extend java.lang.Record
val interfaces = interfacesOpt() // records may implement interfaces
val (statics, body) = typeBody(RECORD, name, tparams)

// We need to generate accessors for every param, if no method with the same name is already defined

var fieldsByName = header.map(v => (v.name, (v.tpt, v.mods.annotations))).to(LinkedHashMap)

for case DefDef(name, paramss, _, _) <- body
if paramss.isEmpty && fieldsByName.contains(name)
do
fieldsByName -= name
end for

val accessors =
(for (name, (tpt, annots)) <- fieldsByName yield
DefDef(name, Nil, tpt, unimplementedExpr)
.withMods(Modifiers(Flags.JavaDefined | Flags.Method | Flags.Synthetic))
).toList

// generate the canonical constructor
val canonicalConstructor =
DefDef(nme.CONSTRUCTOR, joinParams(tparams, List(header)), TypeTree(), EmptyTree)
.withMods(Modifiers(Flags.JavaDefined | Flags.Synthetic, mods.privateWithin))

// return the trees
val recordTypeDef = atSpan(start, nameOffset) {
TypeDef(name,
makeTemplate(
parents = superclass :: interfaces,
stats = canonicalConstructor :: accessors ::: body,
tparams = tparams,
true
)
).withMods(mods)
}
addCompanionObject(statics, recordTypeDef)
end recordDecl

def interfaceDecl(start: Offset, mods: Modifiers): List[Tree] = {
accept(INTERFACE)
val nameOffset = in.offset
Expand Down Expand Up @@ -846,7 +913,8 @@ object JavaParsers {
else if (in.token == SEMI)
in.nextToken()
else {
if (in.token == ENUM || definesInterface(in.token)) mods |= Flags.JavaStatic
adaptRecordIdentifier()
if (in.token == ENUM || in.token == RECORD || definesInterface(in.token)) mods |= Flags.JavaStatic
val decls = memberDecl(start, mods, parentToken, parentTParams)
(if (mods.is(Flags.JavaStatic) || inInterface && !(decls exists (_.isInstanceOf[DefDef])))
statics
Expand Down Expand Up @@ -947,13 +1015,13 @@ object JavaParsers {
}
}

def typeDecl(start: Offset, mods: Modifiers): List[Tree] = in.token match {
def typeDecl(start: Offset, mods: Modifiers): List[Tree] = in.token match
case ENUM => enumDecl(start, mods)
case INTERFACE => interfaceDecl(start, mods)
case AT => annotationDecl(start, mods)
case CLASS => classDecl(start, mods)
case RECORD => recordDecl(start, mods)
case _ => in.nextToken(); syntaxError(em"illegal start of type declaration", skipIt = true); List(errorTypeTree)
}

def tryConstant: Option[Constant] = {
val negate = in.token match {
Expand Down Expand Up @@ -1004,6 +1072,7 @@ object JavaParsers {
if (in.token != EOF) {
val start = in.offset
val mods = modifiers(inInterface = false)
adaptRecordIdentifier() // needed for typeDecl
buf ++= typeDecl(start, mods)
}
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/parsing/JavaTokens.scala
Expand Up @@ -41,6 +41,9 @@ object JavaTokens extends TokensCommon {
inline val SWITCH = 133; enter(SWITCH, "switch")
inline val ASSERT = 134; enter(ASSERT, "assert")

/** contextual keywords (turned into keywords in certain conditions, see JLS 3.9 of Java 9+) */
inline val RECORD = 135; enter(RECORD, "record")

/** special symbols */
inline val EQEQ = 140
inline val BANGEQ = 141
Expand Down
18 changes: 13 additions & 5 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Expand Up @@ -862,7 +862,6 @@ class Namer { typer: Typer =>
* with a user-defined method in the same scope with a matching type.
*/
private def invalidateIfClashingSynthetic(denot: SymDenotation): Unit =

def isCaseClassOrCompanion(owner: Symbol) =
owner.isClass && {
if (owner.is(Module)) owner.linkedClass.is(CaseClass)
Expand All @@ -879,10 +878,19 @@ class Namer { typer: Typer =>
!sd.symbol.is(Deferred) && sd.matches(denot)))

val isClashingSynthetic =
denot.is(Synthetic, butNot = ConstructorProxy)
&& desugar.isRetractableCaseClassMethodName(denot.name)
&& isCaseClassOrCompanion(denot.owner)
&& (definesMember || inheritsConcreteMember)
denot.is(Synthetic, butNot = ConstructorProxy) &&
(
(desugar.isRetractableCaseClassMethodName(denot.name)
&& isCaseClassOrCompanion(denot.owner)
&& (definesMember || inheritsConcreteMember)
)
||
// remove synthetic constructor of a java Record if it clashes with a non-synthetic constructor
(denot.isConstructor
&& denot.owner.is(JavaDefined) && denot.owner.derivesFrom(defn.JavaRecordClass)
&& denot.owner.unforcedDecls.lookupAll(denot.name).exists(c => c != denot.symbol && c.info.matches(denot.info))
)
)

if isClashingSynthetic then
typr.println(i"invalidating clashing $denot in ${denot.owner}")
Expand Down
14 changes: 10 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Expand Up @@ -2441,11 +2441,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}

def typedDefDef(ddef: untpd.DefDef, sym: Symbol)(using Context): Tree = {
if (!sym.info.exists) { // it's a discarded synthetic case class method, drop it
assert(sym.is(Synthetic) && desugar.isRetractableCaseClassMethodName(sym.name))
def canBeInvalidated(sym: Symbol): Boolean =
sym.is(Synthetic)
&& (desugar.isRetractableCaseClassMethodName(sym.name) ||
(sym.isConstructor && sym.owner.derivesFrom(defn.JavaRecordClass)))

if !sym.info.exists then
// it's a discarded method (synthetic case class method or synthetic java record constructor), drop it
assert(canBeInvalidated(sym))
sym.owner.info.decls.openForMutations.unlink(sym)
return EmptyTree
}

// TODO: - Remove this when `scala.language.experimental.erasedDefinitions` is no longer experimental.
// - Modify signature to `erased def erasedValue[T]: T`
if sym.eq(defn.Compiletime_erasedValue) then
Expand Down Expand Up @@ -3598,7 +3604,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
adapt(tree, pt, ctx.typerState.ownedVars)

private def adapt1(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported)
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported, i"tree: $tree, pt: $pt")
def methodStr = err.refStr(methPart(tree).tpe)

def readapt(tree: Tree)(using Context) = adapt(tree, pt, locked)
Expand Down
11 changes: 8 additions & 3 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Expand Up @@ -29,7 +29,7 @@ class CompilationTests {

@Test def pos: Unit = {
implicit val testGroup: TestGroup = TestGroup("compilePos")
aggregateTests(
var tests = List(
compileFile("tests/pos/nullarify.scala", defaultOptions.and("-Ycheck:nullarify")),
compileFile("tests/pos-special/utf8encoded.scala", explicitUTF8),
compileFile("tests/pos-special/utf16encoded.scala", explicitUTF16),
Expand Down Expand Up @@ -65,8 +65,13 @@ class CompilationTests {
compileFile("tests/pos-special/extend-java-enum.scala", defaultOptions.and("-source", "3.0-migration")),
compileFile("tests/pos-custom-args/help.scala", defaultOptions.and("-help", "-V", "-W", "-X", "-Y")),
compileFile("tests/pos-custom-args/i13044.scala", defaultOptions.and("-Xmax-inlines:33")),
compileFile("tests/pos-custom-args/jdk-8-app.scala", defaultOptions.and("-release:8")),
).checkCompile()
compileFile("tests/pos-custom-args/jdk-8-app.scala", defaultOptions.and("-release:8"))
)

if scala.util.Properties.isJavaAtLeast("16") then
tests ::= compileFilesInDir("tests/pos-java16+", defaultOptions.and("-Ysafe-init"))

aggregateTests(tests*).checkCompile()
}

@Test def rewrites: Unit = {
Expand Down
43 changes: 43 additions & 0 deletions tests/pos-java16+/java-records/FromScala.scala
@@ -0,0 +1,43 @@
object C:
def useR1: Unit =
// constructor signature
val r = R1(123, "hello")

// accessors
val i: Int = r.i
val s: String = r.s

// methods
val iRes: Int = r.getInt()
val sRes: String = r.getString()

// supertype
val record: java.lang.Record = r

def useR2: Unit =
// constructor signature
val r2 = R2.R(123, "hello")

// accessors signature
val i: Int = r2.i
val s: String = r2.s

// method
val i2: Int = r2.getInt

// supertype
val isIntLike: IntLike = r2
val isRecord: java.lang.Record = r2

def useR3 =
// constructor signature
val r3 = R3(123, 42L, "hi")
new R3("hi", 123)
// accessors signature
val i: Int = r3.i
val l: Long = r3.l
val s: String = r3.s
// method
val l2: Long = r3.l(43L, 44L)
// supertype
val isRecord: java.lang.Record = r3
2 changes: 2 additions & 0 deletions tests/pos-java16+/java-records/IntLike.scala
@@ -0,0 +1,2 @@
trait IntLike:
def getInt: Int
9 changes: 9 additions & 0 deletions tests/pos-java16+/java-records/R1.java
@@ -0,0 +1,9 @@
public record R1(int i, String s) {
public String getString() {
return s + i;
}

public int getInt() {
return 0;
}
}
13 changes: 13 additions & 0 deletions tests/pos-java16+/java-records/R2.java
@@ -0,0 +1,13 @@
public class R2 {
final record R(int i, String s) implements IntLike {
public int getInt() {
return i;
}

// Canonical constructor
public R(int i, java.lang.String s) {
this.i = i;
this.s = s.intern();
}
}
}
22 changes: 22 additions & 0 deletions tests/pos-java16+/java-records/R3.java
@@ -0,0 +1,22 @@
public record R3(int i, long l, String s) {

// User-specified accessor
public int i() {
return i + 1; // evil >:)
}

// Not an accessor - too many parameters
public long l(long a1, long a2) {
return a1 + a2;
}

// Secondary constructor
public R3(String s, int i) {
this(i, 42L, s);
}

// Compact constructor
public R3 {
s = s.intern();
}
}

0 comments on commit 4f3632f

Please sign in to comment.