Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harden REPL in presence of values that fail to initialize #14702

Merged
merged 2 commits into from Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 13 additions & 7 deletions compiler/src/dotty/tools/repl/Rendering.scala
Expand Up @@ -129,25 +129,31 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) {
infoDiagnostic(d.symbol.showUser, d)

/** Render value definition result */
def renderVal(d: Denotation)(using Context): Option[Diagnostic] =
def renderVal(d: Denotation)(using Context): Either[InvocationTargetException, Option[Diagnostic]] =
val dcl = d.symbol.showUser
def msg(s: String) = infoDiagnostic(s, d)
try
if (d.symbol.is(Flags.Lazy)) Some(msg(dcl))
else valueOf(d.symbol).map(value => msg(s"$dcl = $value"))
catch case e: InvocationTargetException => Some(msg(renderError(e, d)))
Right(
if d.symbol.is(Flags.Lazy) then Some(msg(dcl))
else valueOf(d.symbol).map(value => msg(s"$dcl = $value"))
)
catch case e: InvocationTargetException => Left(e)
end renderVal

/** Force module initialization in the absence of members. */
def forceModule(sym: Symbol)(using Context): Seq[Diagnostic] =
import scala.util.control.NonFatal
def load() =
val objectName = sym.fullName.encode.toString
Class.forName(objectName, true, classLoader())
Nil
try load() catch case e: ExceptionInInitializerError => List(infoDiagnostic(renderError(e, sym.denot), sym.denot))
try load()
catch
case e: ExceptionInInitializerError => List(renderError(e, sym.denot))
case NonFatal(e) => List(renderError(InvocationTargetException(e), sym.denot))

/** Render the stack trace of the underlying exception. */
private def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): String =
def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): Diagnostic =
import dotty.tools.dotc.util.StackTraceOps._
val cause = ite.getCause match
case e: ExceptionInInitializerError => e.getCause
Expand All @@ -159,7 +165,7 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) {
ste.getClassName.startsWith(REPL_WRAPPER_NAME_PREFIX) // d.symbol.owner.name.show is simple name
&& (ste.getMethodName == nme.STATIC_CONSTRUCTOR.show || ste.getMethodName == nme.CONSTRUCTOR.show)

cause.formatStackTracePrefix(!isWrapperInitialization(_))
infoDiagnostic(cause.formatStackTracePrefix(!isWrapperInitialization(_)), d)
end renderError

private def infoDiagnostic(msg: String, d: Denotation)(using Context): Diagnostic =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/repl/ReplCompiler.scala
Expand Up @@ -61,7 +61,7 @@ class ReplCompiler extends Compiler {
val rootCtx = super.rootContext.fresh
.setOwner(defn.EmptyPackageClass)
.withRootImports
(1 to state.objectIndex).foldLeft(rootCtx)((ctx, id) =>
(state.validObjectIndexes).foldLeft(rootCtx)((ctx, id) =>
importPreviousRun(id)(using ctx))
}
}
Expand Down
51 changes: 38 additions & 13 deletions compiler/src/dotty/tools/repl/ReplDriver.scala
Expand Up @@ -35,6 +35,7 @@ import dotty.tools.runner.ScalaClassLoader.*
import org.jline.reader._

import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.util.Using

Expand All @@ -55,12 +56,15 @@ import scala.util.Using
* @param objectIndex the index of the next wrapper
* @param valIndex the index of next value binding for free expressions
* @param imports a map from object index to the list of user defined imports
* @param invalidObjectIndexes the set of object indexes that failed to initialize
* @param context the latest compiler context
*/
case class State(objectIndex: Int,
valIndex: Int,
imports: Map[Int, List[tpd.Import]],
context: Context)
invalidObjectIndexes: Set[Int],
context: Context):
def validObjectIndexes = (1 to objectIndex).filterNot(invalidObjectIndexes.contains(_))

/** Main REPL instance, orchestrating input, compilation and presentation */
class ReplDriver(settings: Array[String],
Expand Down Expand Up @@ -94,7 +98,7 @@ class ReplDriver(settings: Array[String],
}

/** the initial, empty state of the REPL session */
final def initialState: State = State(0, 0, Map.empty, rootCtx)
final def initialState: State = State(0, 0, Map.empty, Set.empty, rootCtx)

/** Reset state of repl to the initial state
*
Expand Down Expand Up @@ -237,7 +241,7 @@ class ReplDriver(settings: Array[String],
completions.map(_.label).distinct.map(makeCandidate)
}
.getOrElse(Nil)
end completions
end completions

private def interpret(res: ParseResult)(implicit state: State): State = {
res match {
Expand Down Expand Up @@ -353,14 +357,33 @@ class ReplDriver(settings: Array[String],
val typeAliases =
info.bounds.hi.typeMembers.filter(_.symbol.info.isTypeAlias)

val formattedMembers =
typeAliases.map(rendering.renderTypeAlias) ++
defs.map(rendering.renderMethod) ++
vals.flatMap(rendering.renderVal)

val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers

(state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics)
// The wrapper object may fail to initialize if the rhs of a ValDef throws.
// In that case, don't attempt to render any subsequent vals, and mark this
// wrapper object index as invalid.
var failedInit = false
val renderedVals =
val buf = mutable.ListBuffer[Diagnostic]()
for d <- vals do if !failedInit then rendering.renderVal(d) match
case Right(Some(v)) =>
buf += v
case Left(e) =>
buf += rendering.renderError(e, d)
failedInit = true
case _ =>
buf.toList

if failedInit then
// We limit the returned diagnostics here to `renderedVals`, which will contain the rendered error
// for the val which failed to initialize. Since any other defs, aliases, imports, etc. from this
// input line will be inaccessible, we avoid rendering those so as not to confuse the user.
(state.copy(invalidObjectIndexes = state.invalidObjectIndexes + state.objectIndex), renderedVals)
else
val formattedMembers =
typeAliases.map(rendering.renderTypeAlias)
++ defs.map(rendering.renderMethod)
++ renderedVals
val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers
(state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics)
}
else (state, Seq.empty)

Expand All @@ -378,8 +401,10 @@ class ReplDriver(settings: Array[String],
tree.symbol.info.memberClasses
.find(_.symbol.name == newestWrapper.moduleClassName)
.map { wrapperModule =>
val formattedTypeDefs = typeDefs(wrapperModule.symbol)
val (newState, formattedMembers) = extractAndFormatMembers(wrapperModule.symbol)
val formattedTypeDefs = // don't render type defs if wrapper initialization failed
if newState.invalidObjectIndexes.contains(state.objectIndex) then Seq.empty
else typeDefs(wrapperModule.symbol)
val highlighted = (formattedTypeDefs ++ formattedMembers)
.map(d => new Diagnostic(d.msg.mapMsg(SyntaxHighlighting.highlight), d.pos, d.level))
(newState, highlighted)
Expand Down Expand Up @@ -420,7 +445,7 @@ class ReplDriver(settings: Array[String],

case Imports =>
for {
objectIndex <- 1 to state.objectIndex
objectIndex <- state.validObjectIndexes
imp <- state.imports.getOrElse(objectIndex, Nil)
} out.println(imp.show(using state.context))
state
Expand Down
104 changes: 104 additions & 0 deletions compiler/test/dotty/tools/repl/ReplCompilerTests.scala
Expand Up @@ -243,6 +243,110 @@ class ReplCompilerTests extends ReplTest:
assertEquals(List("// defined class C"), lines())
}

def assertNotFoundError(id: String): Unit =
val lines = storedOutput().linesIterator
assert(lines.next().startsWith("-- [E006] Not Found Error:"))
assert(lines.drop(2).next().trim().endsWith(s"Not found: $id"))

@Test def i4416 = initially {
val state = run("val x = 1 / 0")
val all = lines()
assertEquals(2, all.length)
assert(all.head.startsWith("java.lang.ArithmeticException:"))
state
} andThen {
val state = run("def foo = x")
assertNotFoundError("x")
state
} andThen {
run("x")
assertNotFoundError("x")
}

@Test def i4416b = initially {
val state = run("val a = 1234")
val _ = storedOutput() // discard output
state
} andThen {
val state = run("val a = 1; val x = ???; val y = x")
val all = lines()
assertEquals(3, all.length)
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
state
} andThen {
val state = run("x")
assertNotFoundError("x")
state
} andThen {
val state = run("y")
assertNotFoundError("y")
state
} andThen {
run("a") // `a` should retain its original binding
assertEquals("val res0: Int = 1234", storedOutput().trim)
}

@Test def i4416_imports = initially {
run("import scala.collection.mutable")
} andThen {
val state = run("import scala.util.Try; val x = ???")
val _ = storedOutput() // discard output
state
} andThen {
run(":imports") // scala.util.Try should not be imported
assertEquals("import scala.collection.mutable", storedOutput().trim)
}

@Test def i4416_types_defs_aliases = initially {
val state =
run("""|type Foo = String
|trait Bar
|def bar: Bar = ???
|val x = ???
|""".stripMargin)
val all = lines()
assertEquals(3, all.length)
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
assert("type alias in failed wrapper should not be rendered",
!all.exists(_.startsWith("// defined alias type Foo = String")))
assert("type definitions in failed wrapper should not be rendered",
!all.exists(_.startsWith("// defined trait Bar")))
assert("defs in failed wrapper should not be rendered",
!all.exists(_.startsWith("def bar: Bar")))
state
} andThen {
val state = run("def foo: Foo = ???")
assertNotFoundError("type Foo")
state
} andThen {
val state = run("type B = Bar")
assertNotFoundError("type Bar")
state
} andThen {
run("bar")
assertNotFoundError("bar")
}

@Test def i14473 = initially {
run("""val (x,y) = if true then "hi" else (42,17)""")
val all = lines()
assertEquals(2, all.length)
assertEquals("scala.MatchError: hi (of class java.lang.String)", all.head)
}

@Test def i14701 = initially {
val state = run("val _ = ???")
val all = lines()
assertEquals(3, all.length)
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
state
} andThen {
run("val _ = assert(false)")
val all = lines()
assertEquals(3, all.length)
assertEquals("java.lang.AssertionError: assertion failed", all.head)
}

@Test def i14491 =
initially {
run("import language.experimental.fewerBraces")
Expand Down