Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: index type hierarchy in java files #6189

Merged
merged 1 commit into from Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions metals-bench/src/main/scala/bench/MetalsBench.scala
Expand Up @@ -11,6 +11,7 @@ import scala.meta.internal.metals.JdkSources
import scala.meta.internal.metals.ReportContext
import scala.meta.internal.metals.logging.MetalsLogger
import scala.meta.internal.mtags.JavaMtags
import scala.meta.internal.mtags.JavaToplevelMtags
import scala.meta.internal.mtags.Mtags
import scala.meta.internal.mtags.OnDemandSymbolIndex
import scala.meta.internal.mtags.ScalaMtags
Expand Down Expand Up @@ -179,6 +180,14 @@ class MetalsBench {
}
}

// Single-shot benchmark: runs the Java toplevel indexer, with inner classes
// included, over every input in `javaDependencySources`.
@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def toplevelJavaMtags(): Unit = {
  for (javaInput <- javaDependencySources.inputs) {
    val indexer = new JavaToplevelMtags(javaInput, includeInnerClasses = true)
    indexer.index()
  }
}

@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def indexSources(): Unit = {
Expand Down
2 changes: 2 additions & 0 deletions metals/src/main/resources/db/migration/V6__Delete_indices.sql
@@ -0,0 +1,2 @@
-- indexing type hierarchy has changed, so we want to reindex
delete from indexed_jar;
147 changes: 121 additions & 26 deletions mtags/src/main/scala/scala/meta/internal/mtags/JavaToplevelMtags.scala
Expand Up @@ -10,20 +10,31 @@ import scala.meta.internal.semanticdb.SymbolInformation
import scala.meta.internal.tokenizers.Chars._
import scala.meta.internal.tokenizers.Reporter

class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
class JavaToplevelMtags(
val input: Input.VirtualFile,
includeInnerClasses: Boolean
) extends MtagsIndexer {

import JavaToplevelMtags._

val reporter: Reporter = Reporter(input)
val reader: CharArrayReader =
new CharArrayReader(input, dialects.Scala213, reporter)

// Returns every (owner, overridden symbols) pair recorded so far through `addOverridden`.
override def overrides(): List[(String, List[OverriddenSymbol])] =
overridden.result

// Accumulator for the pairs returned by `overrides()`.
private val overridden = List.newBuilder[(String, List[OverriddenSymbol])]

// Records `symbols` against the value `currentOwner` holds at the time of the call.
private def addOverridden(symbols: List[OverriddenSymbol]) =
overridden += ((currentOwner, symbols))

override def language: Language = Language.JAVA

override def indexRoot(): Unit = {
if (!input.path.endsWith("module-info.java")) {
reader.nextRawChar()
loop
loop(None)
}
}

Expand All @@ -35,29 +46,90 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
}
}

private def loop: Unit = {
@tailrec
private def loop(region: Option[Region]): Unit = {
val token = fetchToken
token match {
case Token.EOF =>
case Token.Package =>
val paths = readPaths
paths.foreach { path => pkg(path.value, path.pos) }
loop
loop(region)
case Token.Class | Token.Interface | _: Token.Enum | _: Token.Record =>
fetchToken match {
case Token.Word(v, pos) =>
val kind = token match {
case Token.Interface => SymbolInformation.Kind.INTERFACE
case _ => SymbolInformation.Kind.CLASS
}
withOwner(currentOwner)(tpe(v, pos, kind, 0))
skipBody
loop
val previousOwner = currentOwner
tpe(v, pos, kind, 0)
if (includeInnerClasses) {
collectTypeHierarchyInformation
loop(Some(Region(region, currentOwner, lBraceCount = 1)))
} else {
skipBody
currentOwner = previousOwner
loop(region)
}
case Token.LBrace =>
loop(region.map(_.lBrace()))
case Token.RBrace =>
val newRegion = region.flatMap(_.rBrace())
newRegion.foreach(reg => currentOwner = reg.owner)
loop(newRegion)
case _ =>
loop
loop(region)
}
case Token.LBrace =>
loop(region.map(_.lBrace()))
case Token.RBrace =>
val newRegion = region.flatMap(_.rBrace())
newRegion.foreach(reg => currentOwner = reg.owner)
loop(newRegion)
case _ =>
loop(region)
}
}

// Reads tokens following a type name and, when an `extends` or `implements`
// clause is present, records the names found in it for the current owner
// through `addOverridden` as `UnresolvedOverriddenSymbol`s.
// Reading stops at the first `{` or at end of input.
private def collectTypeHierarchyInformation: Unit = {
val implementsOrExtends = List.newBuilder[String]
// Consumes tokens until the first `extends`/`implements`, `{` or EOF and
// returns whichever of those stopped the scan.
@tailrec
def skipUntilOptImplementsOrExtends: Token = {
fetchToken match {
case t @ (Token.Implements | Token.Extends) => t
case Token.EOF => Token.EOF
case Token.LBrace => Token.LBrace
case _ => skipUntilOptImplementsOrExtends
}
}

// Collects every Word token up to the next `{` or EOF. Any other token,
// including a later `extends`/`implements`, is skipped, so both clauses
// end up in the same builder.
@tailrec
def collectHierarchy: Unit = {
fetchToken match {
case Token.Word(v, _) =>
// Record the identifier as a parent name.
// NOTE(review): every Word is recorded, so a dotted name would presumably
// contribute each of its segments separately — confirm this is intended.
implementsOrExtends += v
collectHierarchy
// `{` ends the clause; nothing more to collect.
case Token.LBrace =>
// Skip a parenthesised section entirely.
case Token.LParen =>
skipBalanced(Token.LParen, Token.RParen)
collectHierarchy
// Skip a `<...>` section so names inside it are not recorded.
case Token.LessThan =>
skipBalanced(Token.LessThan, Token.GreaterThan)
collectHierarchy
case Token.EOF =>
case _ => collectHierarchy
}
}

skipUntilOptImplementsOrExtends match {
case Token.Implements | Token.Extends =>
collectHierarchy
// Duplicates are dropped before the names are recorded.
addOverridden(
implementsOrExtends.result.distinct.map(UnresolvedOverriddenSymbol(_))
)
case _ =>
// NOTE(review): this bare `loop` looks like a removed line left over from
// the diff rendering (the visible new `loop` takes a region argument) — confirm.
loop
}
}

Expand Down Expand Up @@ -103,6 +175,8 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
case "interface" => Token.Interface
case "record" => Token.Record(pos)
case "enum" => Token.Enum(pos)
case "extends" => Token.Extends
case "implements" => Token.Implements
case ident =>
Token.Word(ident, pos)
}
Expand All @@ -113,8 +187,8 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
def parseToken: (Token, Boolean) = {
val first = reader.ch
first match {
case ',' | '<' | '>' | '&' | '|' | '!' | '=' | '+' | '-' | '*' | '@' |
':' | '?' | '%' | '^' | '~' =>
case ',' | '&' | '|' | '!' | '=' | '+' | '-' | '*' | '@' | ':' | '?' |
'%' | '^' | '~' =>
(Token.SpecialSym, false)
case SU => (Token.EOF, false)
case '.' => (Token.Dot, false)
Expand All @@ -125,6 +199,8 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
case ')' => (Token.RParen, false)
case '[' => (Token.LBracket, false)
case ']' => (Token.RBracket, false)
case '<' => (Token.LessThan, false)
case '>' => (Token.GreaterThan, false)
case '"' => (quotedLiteral('"'), false)
case '\'' => (quotedLiteral('\''), false)
case '/' =>
Expand Down Expand Up @@ -190,22 +266,26 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
skipToFirstBrace
}

@tailrec
def skipToRbrace(open: Int): Unit = {
fetchToken match {
case Token.RBrace if open == 1 => ()
case Token.RBrace =>
skipToRbrace(open - 1)
case Token.LBrace =>
skipToRbrace(open + 1)
case Token.EOF => ()
case _ =>
skipToRbrace(open)
}
}

skipToFirstBrace
skipToRbrace(1)
skipBalanced(Token.LBrace, Token.RBrace)
}

// Consumes tokens until the `closingToken` matching an already-consumed
// `openingToken` has been read, or until EOF.
// `open` is the number of currently unclosed `openingToken`s; callers that
// have just read the opening token rely on the default of 1.
@tailrec
private def skipBalanced(
    openingToken: Token,
    closingToken: Token,
    open: Int = 1
): Unit = {
  val current = fetchToken
  if (current == closingToken) {
    // The closer of the outermost pair ends the skip; any other closer
    // just reduces the nesting depth.
    if (open != 1) skipBalanced(openingToken, closingToken, open - 1)
  } else if (current == openingToken) {
    skipBalanced(openingToken, closingToken, open + 1)
  } else if (current != Token.EOF) {
    skipBalanced(openingToken, closingToken, open)
  }
}

private def skipLine: Unit =
Expand Down Expand Up @@ -260,12 +340,16 @@ object JavaToplevelMtags {
case class Record(pos: Position) extends WithPos {
val value: String = "record"
}
case object Implements extends Token
case object Extends extends Token
case object RBrace extends Token
case object LBrace extends Token
case object RParen extends Token
case object LParen extends Token
case object RBracket extends Token
case object LBracket extends Token
case object LessThan extends Token
case object GreaterThan extends Token
case object Semicolon extends Token
// any allowed symbol like `=` , `-` and others
case object SpecialSym extends Token
Expand All @@ -277,4 +361,15 @@ object JavaToplevelMtags {
}

}

/**
 * A brace-delimited region belonging to `owner`.
 *
 * @param previousRegion the enclosing region, returned once this one is closed
 * @param owner symbol that owns this region
 * @param lBraceCount number of `{` seen for this region that are not yet closed
 */
case class Region(
    previousRegion: Option[Region],
    owner: String,
    lBraceCount: Int
) {

  /** The same region with one more open brace. */
  def lBrace(): Region = copy(lBraceCount = lBraceCount + 1)

  /**
   * The region in effect after a `}`: the enclosing region when this was the
   * last open brace, otherwise this region with one fewer open brace.
   */
  def rBrace(): Option[Region] = lBraceCount match {
    case 1 => previousRegion
    case n => Some(copy(lBraceCount = n - 1))
  }
}
}
4 changes: 2 additions & 2 deletions mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala
Expand Up @@ -31,7 +31,7 @@ final class Mtags(implicit rc: ReportContext) {
if (language.isJava || language.isScala) {
val mtags =
if (language.isJava)
new JavaToplevelMtags(input)
new JavaToplevelMtags(input, includeInnerClasses = false)
else
new ScalaToplevelMtags(
input,
Expand Down Expand Up @@ -59,7 +59,7 @@ final class Mtags(implicit rc: ReportContext) {
if (language.isJava || language.isScala) {
val mtags =
if (language.isJava)
new JavaToplevelMtags(input)
new JavaToplevelMtags(input, includeInnerClasses = true)
else
new ScalaToplevelMtags(
input,
Expand Down
Expand Up @@ -114,7 +114,8 @@ final class OnDemandSymbolIndex(
source,
None, {
indexedSources += 1
getOrCreateBucket(dialect).addSourceFile(source, sourceDirectory)
getOrCreateBucket(dialect)
.addSourceFile(source, sourceDirectory, isJava = false)
}
)

Expand Down
Expand Up @@ -52,7 +52,7 @@ class SymbolIndexBucket(
if (sourceJars.addEntry(dir.toNIO)) {
dir.listRecursive.toList.flatMap {
case source if source.isScala =>
addSourceFile(source, Some(dir))
addSourceFile(source, Some(dir), isJava = false)
case _ =>
None
}
Expand All @@ -67,13 +67,9 @@ class SymbolIndexBucket(
try {
root.listRecursive.toList.flatMap {
case source if source.isScala =>
addSourceFile(source, None)
addSourceFile(source, None, isJava = false)
case source if source.isJava =>
addJavaSourceFile(source) match {
case Nil => None
case topLevels =>
Some(IndexingResult(source, topLevels, overrides = Nil))
}
addSourceFile(source, None, isJava = true)
case _ =>
None
}
Expand All @@ -100,39 +96,13 @@ class SymbolIndexBucket(
}
}

/* Sometimes source jars have additional nested directories,
 * in that case java toplevel is not "trivial".
 * See: https://github.com/scalameta/metals/issues/3815
 *
 * Builds the symbol `<package parts joined by "/">/<file name without .java>#`
 * from the package read out of the source. Returns Nil when no package is
 * read or when `isTrivialToplevelSymbol` accepts the symbol; otherwise
 * registers `source` under the symbol in `toplevels` and returns the symbol
 * as a single-element list.
 */
def addJavaSourceFile(source: AbsolutePath): List[String] = {
new JavaToplevelMtags(source.toInput).readPackage match {
case Nil => Nil
case packageParts =>
val className = source.filename.stripSuffix(".java")
val symbol = packageParts.mkString("", "/", s"/$className#")
if (
isTrivialToplevelSymbol(
source.toURI.toString,
symbol,
extension = "java"
)
) Nil
else {
// Add this source to the set of files already known for the symbol.
toplevels.updateWith(symbol) {
case Some(acc) => Some(acc + source)
case None => Some(Set(source))
}
List(symbol)
}
}
}

def addSourceFile(
source: AbsolutePath,
sourceDirectory: Option[AbsolutePath]
sourceDirectory: Option[AbsolutePath],
isJava: Boolean
): Option[IndexingResult] = {
val IndexingResult(path, topLevels, overrides) =
indexSource(source, dialect, sourceDirectory)
indexSource(source, dialect, sourceDirectory, isJava)
topLevels.foreach { symbol =>
toplevels.updateWith(symbol) {
case Some(acc) => Some(acc + source)
Expand All @@ -145,7 +115,8 @@ class SymbolIndexBucket(
private def indexSource(
source: AbsolutePath,
dialect: Dialect,
sourceDirectory: Option[AbsolutePath]
sourceDirectory: Option[AbsolutePath],
isJava: Boolean
): IndexingResult = {
val uri = source.toIdeallyRelativeURI(sourceDirectory)
val (doc, overrides) = mtags.indexWithOverrides(source, dialect)
Expand All @@ -155,8 +126,15 @@ class SymbolIndexBucket(
.map(_.symbol)
val topLevels =
if (source.isAmmoniteScript) sourceTopLevels.toList
else
sourceTopLevels.filter(sym => !isTrivialToplevelSymbol(uri, sym)).toList
else if (isJava) {
sourceTopLevels.toList.headOption
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any idea if the indexing is becoming much slower?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark                      Mode  Cnt  Score   Error  Units
MetalsBench.javaMtagsPackage     ss   10  0.168 ± 0.014   s/op
MetalsBench.toplevelJavaMtags    ss   10  2.843 ± 0.112   s/op

It does get slower but it doesn't seem to be too bad. In practice for e.g. Metals there is no visible slowdown.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok cool, let's merge it then 🚀

Copy link
Contributor Author

@kasiaMarek kasiaMarek Apr 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for e.g. Metals there is no visible slowdown

Actually, I was wrong (I looked at the total workspace indexing time, which can be very misleading on a laptop doing a million things in the background).

The actual slowdown:
samples before:

time: indexed library sources in 7.45s
time: indexed library sources in 9.97s
time: indexed library sources in 11s

samples after:

time: indexed library sources in 26s
time: indexed library sources in 22s
time: indexed library sources in 19s

So there is an over 200% slowdown as visible on the CI.

Indexing Java toplevels doesn't seem to be slower than indexing Scala ones, but previously we hardly indexed Java files at all.

MetalsBench.typeHierarchyIndex     ss   10  0.404 ± 0.008   s/op (for Scala lines: 383135)
~ 1.054 for 1m lines
MetalsBench.toplevelJavaMtags      ss   10  2.843 ± 0.112   s/op (for Java lines: 5170931)
~ 0.54  for 1m lines

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the biggest slowdown I think is because we index JDK multiple times on one machine, which doesn't seem necessary. Should we have a separate database for JDKs then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does seem reasonable to index JDK just once instead of doing it for every project.

.filter(sym => !isTrivialToplevelSymbol(uri, sym, "java"))
.toList
} else {
sourceTopLevels
.filter(sym => !isTrivialToplevelSymbol(uri, sym, "scala"))
.toList
}
IndexingResult(source, topLevels, overrides)
}

Expand Down