Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: go to implementations for multi-module projects #6211

Merged
merged 2 commits into from Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -15,6 +15,7 @@ import scala.meta.internal.metals.BuildTargets
import scala.meta.internal.metals.Compilers
import scala.meta.internal.metals.DefinitionProvider
import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.metals.Report
import scala.meta.internal.metals.ReportContext
import scala.meta.internal.metals.ScalaVersionSelector
import scala.meta.internal.metals.ScalaVersions
Expand All @@ -25,9 +26,7 @@ import scala.meta.internal.mtags.Mtags
import scala.meta.internal.mtags.OverriddenSymbol
import scala.meta.internal.mtags.ResolvedOverriddenSymbol
import scala.meta.internal.mtags.Semanticdbs
import scala.meta.internal.mtags.SymbolDefinition
import scala.meta.internal.mtags.UnresolvedOverriddenSymbol
import scala.meta.internal.mtags.{Symbol => MSymbol}
import scala.meta.internal.parsing.Trees
import scala.meta.internal.pc.PcSymbolInformation
import scala.meta.internal.search.SymbolHierarchyOps._
Expand Down Expand Up @@ -155,23 +154,42 @@ final class ImplementationProvider(
)
.toIterable
} yield {
// 1. Search locally for symbol
// 2. Search inside workspace
// 3. Search classpath via GlobalSymbolTable
val sym = symbolOccurrence.symbol
val dealiased =
if (sym.desc.isType) {
symbolInfo(currentDocument, source, sym).map(
_.map(_.dealiasedSymbol).getOrElse(sym)
)
} else Future.successful(sym)

dealiased.flatMap { dealisedSymbol =>
val isWorkspaceSymbol =
(source.isWorkspaceSource(workspace) &&
currentDocument.definesSymbol(dealisedSymbol)) ||
findSymbolDefinition(dealisedSymbol).exists(
_.path.isWorkspaceSource(workspace)
val currentContainsDefinition =
currentDocument.definesSymbol(dealisedSymbol)
val sourceFiles: Set[AbsolutePath] =
if (currentContainsDefinition) Set(source)
else
definitionProvider
.fromSymbol(dealisedSymbol, Some(source))
.asScala
.map(_.getUri().toAbsolutePath)
.toSet

if (sourceFiles.isEmpty) {
rc.unsanitized.create(
Report(
"missing-definition",
s"""|Missing definition symbol for:
|$dealisedSymbol
|""".stripMargin,
s"missing def: $dealisedSymbol",
Some(source.toURI.toString()),
)
)
}

val isWorkspaceSymbol =
(currentContainsDefinition && source.isWorkspaceSource(workspace)) ||
sourceFiles.forall(_.isWorkspaceSource(workspace))

val workspaceInheritanceContext: InheritanceContext =
InheritanceContext.fromDefinitions(
Expand All @@ -193,6 +211,7 @@ final class ImplementationProvider(
dealisedSymbol,
currentDocument,
source,
sourceFiles,
inheritanceContext,
)
}
Expand All @@ -207,6 +226,7 @@ final class ImplementationProvider(
dealiased: String,
textDocument: TextDocument,
source: AbsolutePath,
definitionFiles: Set[AbsolutePath],
inheritanceContext: InheritanceContext,
): Future[Seq[Location]] = {

Expand Down Expand Up @@ -276,13 +296,20 @@ final class ImplementationProvider(
locationsByFile: Map[Path, Set[ClassLocation]],
parentSymbol: PcSymbolInformation,
classSymbol: String,
buildTarget: BuildTargetIdentifier,
definitionBuildTargets: Set[BuildTargetIdentifier],
) = Future.sequence({
def allDependencyBuildTargets(implPath: AbsolutePath) = {
val targets = buildTargets.inverseSourcesAll(implPath)
buildTargets.buildTargetTransitiveDependencies(targets).toSet ++ targets
}

for {
file <- files
locations = locationsByFile(file)
implPath = AbsolutePath(file)
if (buildTargets.belongsToBuildTarget(buildTarget, implPath))
if (definitionBuildTargets.isEmpty || allDependencyBuildTargets(
implPath
).exists(definitionBuildTargets(_)))
implDocument <- findSemanticdb(implPath).toList
} yield {
for {
Expand Down Expand Up @@ -331,7 +358,6 @@ final class ImplementationProvider(
(for {
symbolInfo <- optSymbolInfo
symbolClass <- classFromSymbol(symbolInfo)
target <- buildTargets.inverseSources(source)
} yield {
for {
locationsByFile <- findImplementation(
Expand All @@ -350,7 +376,9 @@ final class ImplementationProvider(
locationsByFile,
symbolInfo,
symbolClass,
target,
definitionFiles.flatMap(
buildTargets.inverseSourcesAll(_).toSet
),
)
)
)
Expand Down Expand Up @@ -425,10 +453,6 @@ final class ImplementationProvider(
})
}

private def findSymbolDefinition(symbol: String): Option[SymbolDefinition] = {
index.definition(MSymbol(symbol))
}

private def classFromSymbol(info: PcSymbolInformation): Option[String] =
if (classLikeKinds(info.kind)) Some(info.dealiasedSymbol)
else info.classOwner
Expand Down
Expand Up @@ -51,8 +51,8 @@ class GlobalInheritanceContext(
val resolveGlobal =
implementationsInDependencySources
.getOrElse(shortName, Set.empty)
.collect { case loc @ ClassLocation(sym, _) =>
compilers.info(source, sym).map {
.collect { case loc @ ClassLocation(sym, Some(filePath)) =>
compilers.info(AbsolutePath(filePath), sym).map {
case Some(symInfo) if symInfo.parents.contains(symbol) => Some(loc)
case Some(symInfo)
if symInfo.dealiasedSymbol == symbol && symInfo.symbol != symbol =>
Expand Down
Expand Up @@ -171,10 +171,15 @@ final class BuildTargets private (

def buildTargetTransitiveDependencies(
id: BuildTargetIdentifier
): Iterable[BuildTargetIdentifier] =
buildTargetTransitiveDependencies(List(id))

def buildTargetTransitiveDependencies(
ids: List[BuildTargetIdentifier]
): Iterable[BuildTargetIdentifier] = {
val isVisited = mutable.Set.empty[BuildTargetIdentifier]
val toVisit = new java.util.ArrayDeque[BuildTargetIdentifier]
toVisit.add(id)
ids.foreach(toVisit.add(_))
while (!toVisit.isEmpty) {
val next = toVisit.pop()
if (!isVisited(next)) {
Expand Down Expand Up @@ -382,15 +387,6 @@ final class BuildTargets private (
}
}

def belongsToBuildTarget(
target: BuildTargetIdentifier,
path: AbsolutePath,
): Boolean = {
val possibleBuildTargets =
buildTargetTransitiveDependencies(target).toSet + target
inverseSourcesAll(path).exists(possibleBuildTargets(_))
}

def inferBuildTarget(
source: AbsolutePath
): Option[BuildTargetIdentifier] =
Expand Down
55 changes: 55 additions & 0 deletions tests/unit/src/test/scala/tests/ImplementationLspSuite.scala
@@ -1,5 +1,9 @@
package tests

import scala.concurrent.Future

import scala.meta.internal.metals.MetalsEnrichments._

class ImplementationLspSuite extends BaseImplementationSuite("implementation") {

check(
Expand Down Expand Up @@ -659,6 +663,57 @@ class ImplementationLspSuite extends BaseImplementationSuite("implementation") {
|""".stripMargin,
)

test("multi-module") {
val fileName = "a/src/main/scala/com/example/foo/Foo.scala"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we instead add a parameter to check method to have a custom metals.json?

val fileContents =
"""|package com.example.foo
|trait F@@oo {
| def transform(input: Int): Int
|}
|""".stripMargin
cleanWorkspace()
for {
_ <- initialize(
s"""/metals.json
|{
| "a":{ },
| "b":{
| "dependsOn": ["a"]
| }
|}
|/$fileName
|${fileContents.replaceAll("@@", "")}
|/b/src/main/scala/com/example/bar/Bar.scala
|package com.example.bar
|
|import com.example.foo.Foo
|
|class Bar extends Foo {
| override def transform(input: Int): Int = input * 2
|}
""".stripMargin
)
_ <- server.didOpen("b/src/main/scala/com/example/bar/Bar.scala")
_ <- server.didOpen(fileName)
_ = assertNoDiagnostics()
locations <- server.implementation(fileName, fileContents)
definitions <-
Future.sequence(
locations.map(location =>
server.server.definitionResult(
location.toTextDocumentPositionParams
)
)
)
symbols = definitions.map(_.symbol).sorted
_ = assertNoDiff(
symbols.mkString("\n"),
"com/example/bar/Bar#",
)
_ <- server.shutdown()
} yield ()
}

override protected def libraryDependencies: List[String] =
List("org.scalatest::scalatest:3.2.16", "io.circe::circe-generic:0.12.0")

Expand Down