forked from scalameta/munit-scalacheck
/
BaseFrameworkSuite.scala
116 lines (112 loc) · 3.42 KB
/
BaseFrameworkSuite.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package munit
import munit.diff.console.AnsiColors
import munit.internal.PlatformCompat
import sbt.testing.Event
import sbt.testing.EventHandler
import sbt.testing.Logger
import sbt.testing.TaskDef
import java.io.ByteArrayOutputStream
import java.io.PrintStream
import java.nio.charset.StandardCharsets
import java.util.regex.Pattern
import scala.concurrent.Future
import scala.util.control.NonFatal
/** Base class for suites that run another test class through the MUnit sbt
  * `Framework` and compare the reported output against an expected string.
  *
  * Subclasses register cases with [[check]].
  */
abstract class BaseFrameworkSuite extends BaseSuite {

  /** The `System.out` stream as it was when this suite was constructed.
    *
    * NOTE(review): not referenced in this class; presumably kept for
    * subclasses — confirm before removing.
    */
  val systemOut: PrintStream = System.out

  /** Skips this suite unless the build's Scala version starts with "2.13".
    *
    * NOTE(review): presumably because the expected outputs are specific to
    * that Scala version — confirm.
    */
  override def munitIgnore: Boolean = !BuildInfo.scalaVersion.startsWith("2.13")

  /** Renders a throwable's message for comparison against expected output.
    *
    * @param ex the throwable reported by a framework event
    * @return `"null"` when the message is null; otherwise the message with
    *   every occurrence of the build's source directory removed and all
    *   backslashes turned into forward slashes, so that reported paths
    *   compare equal regardless of checkout location or operating system.
    */
  def exceptionMessage(ex: Throwable): String = {
    val message = ex.getMessage()
    if (message == null) "null"
    else {
      message
        // Strip the source directory first: on Windows it contains
        // backslashes that must still match before they are normalized.
        .replace(BuildInfo.sourceDirectory.toString(), "")
        .replace('\\', '/')
    }
  }

  /** Registers a test that runs `t.cls` through the MUnit sbt `Framework`
    * and asserts that the produced output matches `t.expected`.
    *
    * The test is named after the simple class name of `t.cls` and tagged
    * with `t.tags`. The obtained text depends on `t.format`:
    *  - `SbtFormat`: for each sbt event, whatever `t.onEvent` returns for it,
    *    followed by the line `==> <status> <fully-qualified name>`. When the
    *    event carries a throwable, ` - <message>` is added to that line.
    *    Any `"""` in the result is replaced by `'''`.
    *  - `StdoutFormat`: everything written to the sbt logger (debug output
    *    excluded), with elapsed times such as ` 0.012s` replaced by
    *    ` <elapsed time>`.
    *
    * ANSI escape codes are stripped in both cases. The captured logger output
    * is passed as the clue of the comparison.
    *
    * @param t the framework test case to run and check
    */
  def check(t: FrameworkTest): Unit = {
    test(t.cls.getSimpleName().withTags(t.tags)) {
      val baos = new ByteArrayOutputStream()
      // Encode with the same charset used to decode `baos` below. The
      // single-argument constructor uses the platform default charset, which
      // garbles non-ASCII output on JVMs whose default is not UTF-8.
      val out = new PrintStream(baos, false, StandardCharsets.UTF_8.name())
      val logger = new Logger {
        def ansiCodesSupported(): Boolean = false
        def error(x: String): Unit = out.println(x)
        def warn(x: String): Unit = out.println(x)
        def info(x: String): Unit = out.println(x)
        def debug(x: String): Unit = () // ignore debugging output
        def trace(x: Throwable): Unit = out.println(x)
      }
      val framework = new Framework
      val runner = framework.runner(
        t.arguments ++ Array("+l"), // use sbt loggers
        Array(),
        PlatformCompat.getThisClassLoader
      )
      val tasks = runner.tasks(
        Array(
          new TaskDef(
            t.cls.getName(),
            framework.munitFingerprint,
            false,
            Array()
          )
        )
      )
      val events = new StringBuilder()
      val eventHandler = new EventHandler {
        def handle(event: Event): Unit = {
          try {
            events.append(t.onEvent(event))
            val status = event.status().toString().toLowerCase()
            val name = event.fullyQualifiedName()
            events
              .append("==> ")
              .append(status)
              .append(" ")
              .append(name)
            if (event.throwable().isDefined()) {
              events
                .append(" - ")
                .append(exceptionMessage(event.throwable().get()))
            }
            events.append("\n")
          } catch {
            case NonFatal(e) =>
              e.printStackTrace()
              // Terminate the line so the next event starts on its own line.
              events.append(s"unexpected error: $e").append("\n")
          }
        }
      }
      implicit val ec: scala.concurrent.ExecutionContext = munitExecutionContext
      val elapsedTimePattern = Pattern.compile(" \\d+\\.\\d+s ?")
      for {
        // Run the tasks strictly one after another: each task starts only
        // when the previous one's future has completed.
        _ <- tasks.foldLeft(Future.successful(())) { case (base, task) =>
          base.flatMap(_ =>
            PlatformCompat.executeAsync(
              task,
              eventHandler,
              Array(logger)
            )
          )
        }
      } yield {
        // Make sure everything printed to the logger has reached `baos`.
        out.flush()
        val stdout =
          AnsiColors.filterAnsi(baos.toString(StandardCharsets.UTF_8.name()))
        val obtained = AnsiColors.filterAnsi(
          t.format match {
            case SbtFormat =>
              events.toString().replace("\"\"\"", "'''")
            case StdoutFormat =>
              elapsedTimePattern.matcher(stdout).replaceAll(" <elapsed time>")
          }
        )
        assertNoDiff(
          obtained,
          t.expected,
          stdout
        )(t.location)
      }
    }(t.location)
  }
}