Skip to content

Commit

Permalink
Extract munit diff module
Browse files Browse the repository at this point in the history
  • Loading branch information
majk-p committed Apr 23, 2024
1 parent 83cb747 commit f0a004c
Show file tree
Hide file tree
Showing 27 changed files with 234 additions and 209 deletions.
26 changes: 25 additions & 1 deletion build.sbt
Expand Up @@ -214,7 +214,7 @@ lazy val munit = crossProject(JSPlatform, JVMPlatform, NativePlatform)
.settings(
sharedSettings,
Compile / unmanagedSourceDirectories ++=
crossBuildingDirectories("munit", "main").value,
crossBuildingDirectories("scala-diff", "main").value,
libraryDependencies ++= List(
"org.scala-lang" % "scala-reflect" % {
if (isScala3Setting.value) scala213
Expand Down Expand Up @@ -246,6 +246,8 @@ lazy val munit = crossProject(JSPlatform, JVMPlatform, NativePlatform)
)
)
.jvmConfigure(_.dependsOn(junit))
.dependsOn(munitDiff)

lazy val munitJVM = munit.jvm
lazy val munitJS = munit.js
lazy val munitNative = munit.native
Expand All @@ -270,6 +272,28 @@ lazy val plugin = project
)
.disablePlugins(MimaPlugin)

lazy val munitDiff = crossProject(JSPlatform, JVMPlatform, NativePlatform)
.in(file("scala-diff"))
.settings(
moduleName := "scala-diff",
sharedSettings,
libraryDependencies ++= List(
"org.scala-lang" % "scala-reflect" % {
if (isScala3Setting.value) scala213
else scalaVersion.value
} % Provided
)
)
.jvmSettings(
sharedJVMSettings
)
.nativeConfigure(sharedNativeConfigure)
.nativeSettings(
sharedNativeSettings
)
.jsConfigure(sharedJSConfigure)
.jsSettings(sharedJSSettings)

lazy val tests = crossProject(JSPlatform, JVMPlatform, NativePlatform)
.dependsOn(munit)
.enablePlugins(BuildInfoPlugin)
Expand Down
Expand Up @@ -4,7 +4,7 @@

package munit.internal.junitinterface

import munit.internal.console.AnsiColors
import munit.diff.console.AnsiColors
import sbt.testing._
import munit.internal.PlatformCompat

Expand Down
30 changes: 24 additions & 6 deletions munit/shared/src/main/scala/munit/Assertions.scala
@@ -1,13 +1,16 @@
package munit

import munit.internal.console.{Lines, Printers, StackTraces}
import munit.internal.difflib.ComparisonFailExceptionHandler
import munit.internal.difflib.Diffs
import munit.internal.console.{Lines, StackTraces}
import munit.internal.console.Printers
import munit.diff.Printer
import munit.Clue
import munit.Clues
import munit.diff.EmptyPrinter

import scala.reflect.ClassTag
import scala.util.control.NonFatal
import scala.collection.mutable
import munit.internal.console.AnsiColors
import munit.diff.console.AnsiColors
import org.junit.AssumptionViolatedException
import munit.internal.MacroCompat

Expand Down Expand Up @@ -51,10 +54,10 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
clue: => Any = "diff assertion failed"
)(implicit loc: Location): Unit = {
StackTraces.dropInside {
Diffs.assertNoDiff(
DiffsAssetion.assertNoDiff(
obtained,
expected,
ComparisonFailExceptionHandler.fromAssertions(this, Clues.empty),
exceptionHandlerFromAssertions(this, Clues.empty),
munitPrint(clue),
printObtainedAsStripMargin = true
)
Expand Down Expand Up @@ -291,6 +294,21 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
)
}

private def exceptionHandlerFromAssertions(
assertions: Assertions,
clues: => Clues
): ComparisonFailExceptionHandler =
new ComparisonFailExceptionHandler {
def handle(
message: String,
obtained: String,
expected: String,
loc: Location
): Nothing = {
assertions.failComparison(message, obtained, expected, clues)(loc)
}
}

private val munitCapturedClues: mutable.ListBuffer[Clue[_]] =
mutable.ListBuffer.empty
def munitCaptureClues[T](thunk: => T): (T, Clues) =
Expand Down
3 changes: 2 additions & 1 deletion munit/shared/src/main/scala/munit/Clue.scala
Expand Up @@ -12,5 +12,6 @@ class Clue[+T](
object Clue extends MacroCompat.ClueMacro {
@deprecated("use fromValue instead", "1.0.0")
def empty[T](value: T): Clue[T] = fromValue(value)
def fromValue[T](value: T): Clue[T] = new Clue("", value, "")
def fromValue[T](value: T): Clue[T] =
new Clue("", value, "")
}
6 changes: 2 additions & 4 deletions munit/shared/src/main/scala/munit/Compare.scala
@@ -1,7 +1,5 @@
package munit

import munit.internal.difflib.Diffs
import munit.internal.difflib.ComparisonFailExceptionHandler
import scala.annotation.implicitNotFound

/**
Expand Down Expand Up @@ -61,7 +59,7 @@ trait Compare[A, B] {
}
// Attempt 1: custom pretty-printer that produces multiline output, which is
// optimized for line-by-line diffing.
Diffs.assertNoDiff(
DiffsAssetion.assertNoDiff(
assertions.munitPrint(obtained),
assertions.munitPrint(expected),
diffHandler,
Expand All @@ -71,7 +69,7 @@ trait Compare[A, B] {

// Attempt 2: try with `.toString` in case `munitPrint()` produces identical
// formatting for both values.
Diffs.assertNoDiff(
DiffsAssetion.assertNoDiff(
obtained.toString(),
expected.toString(),
diffHandler,
Expand Down
@@ -0,0 +1,10 @@
package munit

trait ComparisonFailExceptionHandler {
def handle(
message: String,
obtained: String,
expected: String,
location: Location
): Nothing
}
33 changes: 33 additions & 0 deletions munit/shared/src/main/scala/munit/DiffsAssetion.scala
@@ -0,0 +1,33 @@
package munit

import munit.diff.Diff

object DiffsAssetion {

def assertNoDiff(
obtained: String,
expected: String,
handler: ComparisonFailExceptionHandler,
title: String,
printObtainedAsStripMargin: Boolean
)(implicit loc: Location): Boolean = {
if (obtained.isEmpty && !expected.isEmpty) {
val msg =
s"""|Obtained empty output!
|=> Expected:
|$expected""".stripMargin
handler.handle(msg, obtained, expected, loc)
}
val diff = new Diff(obtained, expected)
if (diff.isEmpty) true
else {
handler.handle(
diff.createReport(title, printObtainedAsStripMargin),
obtained,
expected,
loc
)
}
}

}
10 changes: 8 additions & 2 deletions munit/shared/src/main/scala/munit/internal/console/Lines.scala
Expand Up @@ -44,10 +44,16 @@ class Lines extends Serializable {
.append(format(location.line - 1))
.append(slice(0))
.append('\n')
.append(AnsiColors.use(AnsiColors.Reversed))
.append(
munit.diff.console.AnsiColors
.use(munit.diff.console.AnsiColors.Reversed)
)
.append(format(location.line))
.append(slice(1))
.append(AnsiColors.use(AnsiColors.Reset))
.append(
munit.diff.console.AnsiColors
.use(munit.diff.console.AnsiColors.Reset)
)
if (slice.length >= 3)
out
.append('\n')
Expand Down
60 changes: 10 additions & 50 deletions munit/shared/src/main/scala/munit/internal/console/Printers.scala
@@ -1,13 +1,18 @@
// Adaptation of https://github.com/lihaoyi/PPrint/blob/e6a918c259ed7ae1998bbf58c360334a3f0157ca/pprint/src/pprint/Walker.scala
package munit.internal.console

import munit.internal.Compat
import munit.{EmptyPrinter, Location, Printable, Printer}

import scala.annotation.switch
import munit.Location
import munit.diff.Printer
import munit.diff.EmptyPrinter
import munit.diff.console.{Printers => DiffPrinters}
import munit.Clues
import munit.Printable
import munit.internal.Compat

object Printers {

import DiffPrinters._

def log(any: Any, printer: Printer = EmptyPrinter)(implicit
loc: Location
): Unit = {
Expand Down Expand Up @@ -123,7 +128,7 @@ object Printers {
}
}
loop(any, indent = 0)
AnsiColors.filterAnsi(out.toString())
munit.diff.console.AnsiColors.filterAnsi(out.toString())
}

private def printApply[T](
Expand Down Expand Up @@ -163,31 +168,6 @@ object Printers {
}
}

private def printString(
string: String,
out: StringBuilder,
printer: Printer
): Unit = {
val isMultiline = printer.isMultiline(string)
if (isMultiline) {
out.append('"')
out.append('"')
out.append('"')
out.append(string)
out.append('"')
out.append('"')
out.append('"')
} else {
out.append('"')
var i = 0
while (i < string.length()) {
printChar(string.charAt(i), out)
i += 1
}
out.append('"')
}
}

/**
* Pretty-prints this string with non-visible characters escaped.
*
Expand Down Expand Up @@ -216,24 +196,4 @@ object Printers {
out.toString()
}

private def printChar(
c: Char,
sb: StringBuilder,
isEscapeUnicode: Boolean = true
): Unit =
(c: @switch) match {
case '"' => sb.append("\\\"")
case '\\' => sb.append("\\\\")
case '\b' => sb.append("\\b")
case '\f' => sb.append("\\f")
case '\n' => sb.append("\\n")
case '\r' => sb.append("\\r")
case '\t' => sb.append("\\t")
case c =>
val isNonReadableAscii = c < ' ' || (c > '~' && isEscapeUnicode)
if (isNonReadableAscii && !Character.isLetter(c))
sb.append("\\u%04x".format(c.toInt))
else sb.append(c)
}

}

This file was deleted.

0 comments on commit f0a004c

Please sign in to comment.