diff --git a/detekt-core/src/main/resources/default-detekt-config.yml b/detekt-core/src/main/resources/default-detekt-config.yml
index ce69fea8ae2..307a1365ea6 100644
--- a/detekt-core/src/main/resources/default-detekt-config.yml
+++ b/detekt-core/src/main/resources/default-detekt-config.yml
@@ -505,6 +505,10 @@ style:
active: false
singleLine: 'never'
multiLine: 'always'
+ BracesOnWhenStatements:
+ active: false
+ singleLine: 'necessary'
+ multiLine: 'consistent'
CanBeNonNullable:
active: false
CascadingCallWrapping:
diff --git a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/BracesOnWhenStatements.kt b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/BracesOnWhenStatements.kt
new file mode 100644
index 00000000000..cf97dc368fc
--- /dev/null
+++ b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/BracesOnWhenStatements.kt
@@ -0,0 +1,242 @@
+package io.gitlab.arturbosch.detekt.rules.style
+
+import io.gitlab.arturbosch.detekt.api.CodeSmell
+import io.gitlab.arturbosch.detekt.api.Config
+import io.gitlab.arturbosch.detekt.api.Debt
+import io.gitlab.arturbosch.detekt.api.Entity
+import io.gitlab.arturbosch.detekt.api.Issue
+import io.gitlab.arturbosch.detekt.api.Rule
+import io.gitlab.arturbosch.detekt.api.Severity
+import io.gitlab.arturbosch.detekt.api.config
+import io.gitlab.arturbosch.detekt.api.internal.Configuration
+import org.jetbrains.kotlin.psi.KtBlockExpression
+import org.jetbrains.kotlin.psi.KtWhenEntry
+import org.jetbrains.kotlin.psi.KtWhenExpression
+import org.jetbrains.kotlin.psi.psiUtil.siblings
+
+/**
+ * This rule detects `when` statements which do not comply with the specified policy.
+ * Keeping braces consistent will improve readability and avoid possible errors.
+ *
+ * Single-line `when` statement is:
+ * a `when` where each of the branches are single-line (has no line breaks `\n`).
+ *
+ * Multi-line `when` statement is:
+ * a `when` where at least one of the branches is multi-line (has a break line `\n`).
+ *
+ * Available options are:
+ * * `never`: forces no braces on any branch.
+ * _Tip_: this is very strict, it will force a simple expression, like a single function call / expression.
+ * Extracting a function for "complex" logic is one way to adhere to this policy.
+ * * `necessary`: forces no braces on any branch except where necessary for multi-statement branches.
+ * * `consistent`: ensures that braces are consistent within `when` statement.
+ * If there are braces on one of the branches, all branches should have it.
+ * * `always`: forces braces on all branches.
+ *
+ *
+ * // singleLine = 'never'
+ * when (a) {
+ * 1 -> { f1() } // Not allowed.
+ * 2 -> f2()
+ * }
+ * // multiLine = 'never'
+ * when (a) {
+ * 1 -> { // Not allowed.
+ * f1()
+ * }
+ * 2 -> f2()
+ * }
+ * // singleLine = 'necessary'
+ * when (a) {
+ * 1 -> { f1() } // Unnecessary braces.
+ * 2 -> f2()
+ * }
+ * // multiLine = 'necessary'
+ * when (a) {
+ * 1 -> { // Unnecessary braces.
+ * f1()
+ * }
+ * 2 -> f2()
+ * }
+ *
+ * // singleLine = 'consistent'
+ * when (a) {
+ * 1 -> { f1() }
+ * 2 -> f2()
+ * }
+ * // multiLine = 'consistent'
+ * when (a) {
+ * 1 ->
+ * f1() // Missing braces.
+ * 2 -> {
+ * f2()
+ * f3()
+ * }
+ * }
+ *
+ * // singleLine = 'always'
+ * when (a) {
+ * 1 -> { f1() }
+ * 2 -> f2() // Missing braces.
+ * }
+ * // multiLine = 'always'
+ * when (a) {
+ * 1 ->
+ * f1() // Missing braces.
+ * 2 -> {
+ * f2()
+ * f3()
+ * }
+ * }
+ *
+ *
+ *
+ *
+ * // singleLine = 'never'
+ * when (a) {
+ * 1 -> f1()
+ * 2 -> f2()
+ * }
+ * // multiLine = 'never'
+ * when (a) {
+ * 1 ->
+ * f1()
+ * 2 -> f2()
+ * }
+ * // singleLine = 'necessary'
+ * when (a) {
+ * 1 -> f1()
+ * 2 -> { f2(); f3() } // Necessary braces because of multiple statements.
+ * }
+ * // multiLine = 'necessary'
+ * when (a) {
+ * 1 ->
+ * f1()
+ * 2 -> { // Necessary braces because of multiple statements.
+ * f2()
+ * f3()
+ * }
+ * }
+ *
+ * // singleLine = 'consistent'
+ * when (a) {
+ * 1 -> { f1() }
+ * 2 -> { f2() }
+ * }
+ * when (a) {
+ * 1 -> f1()
+ * 2 -> f2()
+ * }
+ * // multiLine = 'consistent'
+ * when (a) {
+ * 1 -> {
+ * f1()
+ * }
+ * 2 -> {
+ * f2()
+ * f3()
+ * }
+ * }
+ *
+ * // singleLine = 'always'
+ * when (a) {
+ * 1 -> { f1() }
+ * 2 -> { f2() }
+ * }
+ * // multiLine = 'always'
+ * when (a) {
+ * 1 -> {
+ * f1()
+ * }
+ * 2 -> {
+ * f2()
+ * f3()
+ * }
+ * }
+ *
+ *
+ */
+class BracesOnWhenStatements(config: Config = Config.empty) : Rule(config) {
+ override val issue = Issue(
+ javaClass.simpleName,
+ Severity.Style,
+ "Braces do not comply with the specified policy",
+ Debt.FIVE_MINS
+ )
+
+ @Configuration("single-line braces policy")
+ private val singleLine: BracePolicy by config("necessary") { BracePolicy.getValue(it) }
+
+ @Configuration("multi-line braces policy")
+ private val multiLine: BracePolicy by config("consistent") { BracePolicy.getValue(it) }
+
+ override fun visitWhenExpression(expression: KtWhenExpression) {
+ super.visitWhenExpression(expression)
+
+ validate(expression.entries, policy(expression))
+ }
+
+ private fun validate(branches: List, policy: BracePolicy) {
+ val violators = when (policy) {
+ BracePolicy.Always -> {
+ branches.filter { !it.hasBraces() }
+ }
+
+ BracePolicy.Necessary -> {
+ branches.filter { !it.isMultiStatement() && it.hasBraces() }
+ }
+
+ BracePolicy.Never -> {
+ branches.filter { it.hasBraces() }
+ }
+
+ BracePolicy.Consistent -> {
+ val braces = branches.count { it.hasBraces() }
+ val noBraces = branches.count { !it.hasBraces() }
+ if (braces != 0 && noBraces != 0) {
+ branches.take(1)
+ } else {
+ emptyList()
+ }
+ }
+ }
+ violators.forEach { report(it, policy) }
+ }
+
+ private fun KtWhenEntry.hasBraces(): Boolean = expression is KtBlockExpression
+
+ private fun KtWhenEntry.isMultiStatement(): Boolean =
+ expression.let { it is KtBlockExpression && it.statements.size > 1 }
+
+ private fun policy(expression: KtWhenExpression): BracePolicy {
+ val isMultiLine = expression.entries.any { branch ->
+ requireNotNull(branch.arrow) { "When branch ${branch.text} has no arrow!" }
+ .siblings(forward = true, withItself = false)
+ .any { it.textContains('\n') }
+ }
+ return if (isMultiLine) multiLine else singleLine
+ }
+
+ private fun report(violator: KtWhenEntry, policy: BracePolicy) {
+ val reported = when (policy) {
+ BracePolicy.Consistent -> (violator.parent as KtWhenExpression).whenKeyword
+ BracePolicy.Always,
+ BracePolicy.Necessary,
+ BracePolicy.Never -> requireNotNull(violator.arrow) { "When branch ${violator.text} has no arrow!" }
+ }
+ report(CodeSmell(issue, Entity.from(reported), policy.message))
+ }
+
+ enum class BracePolicy(val config: String, val message: String) {
+ Always("always", "Missing braces on this branch, add them."),
+ Consistent("consistent", "Inconsistent braces, make sure all branches either have or don't have braces."),
+ Necessary("necessary", "Extra braces exist on this branch, remove them."),
+ Never("never", "Extra braces exist on this branch, remove them.");
+
+ companion object {
+ fun getValue(arg: String): BracePolicy =
+ values().singleOrNull { it.config == arg }
+ ?: error("Unknown value $arg, allowed values are: ${values().joinToString("|")}")
+ }
+ }
+}
diff --git a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/StyleGuideProvider.kt b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/StyleGuideProvider.kt
index 02a5ae4374f..7cbf54fd2b7 100644
--- a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/StyleGuideProvider.kt
+++ b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/StyleGuideProvider.kt
@@ -78,6 +78,7 @@ class StyleGuideProvider : DefaultRuleSetProvider {
MayBeConst(config),
PreferToOverPairSyntax(config),
BracesOnIfStatements(config),
+ BracesOnWhenStatements(config),
MandatoryBracesLoops(config),
NullableBooleanCheck(config),
VarCouldBeVal(config),
diff --git a/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/BracesOnWhenStatementsSpec.kt b/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/BracesOnWhenStatementsSpec.kt
new file mode 100644
index 00000000000..58485331a98
--- /dev/null
+++ b/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/BracesOnWhenStatementsSpec.kt
@@ -0,0 +1,1006 @@
+package io.gitlab.arturbosch.detekt.rules.style
+
+import io.gitlab.arturbosch.detekt.api.TextLocation
+import io.gitlab.arturbosch.detekt.rules.style.BracesOnWhenStatements.BracePolicy
+import io.gitlab.arturbosch.detekt.rules.style.BracesOnWhenStatementsSpec.Companion.NOT_RELEVANT
+import io.gitlab.arturbosch.detekt.rules.style.BracesOnWhenStatementsSpec.Companion.test
+import io.gitlab.arturbosch.detekt.test.TestConfig
+import io.gitlab.arturbosch.detekt.test.assertThat
+import io.gitlab.arturbosch.detekt.test.compileAndLint
+import io.gitlab.arturbosch.detekt.test.lint
+import org.assertj.core.api.Assertions.assertThat
+import org.assertj.core.api.Assertions.assertThatCode
+import org.junit.jupiter.api.DynamicContainer.dynamicContainer
+import org.junit.jupiter.api.DynamicNode
+import org.junit.jupiter.api.DynamicTest.dynamicTest
+import org.junit.jupiter.api.Nested
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.TestFactory
+
+/**
+ * Note: this class makes extensive use of dynamic tests and containers, few tips for maintenance:
+ * * Numbers are made relative to the code snippet rather than the whole code passed to Kotlin compiler. [test].
+ * * To debug a specific test case, remove all other test cases from the same method.
+ * * Test coverage is added for all possible configuration options, see [NOT_RELEVANT].
+ */
+@Suppress(
+ "ClassName",
+ "LongMethod",
+ "CommentOverPrivateProperty"
+)
+class BracesOnWhenStatementsSpec {
+
+ @Test
+ fun `validate behavior of occurrence function`() {
+ val code = "fun f() { if (true) else if (true) if (true) true }"
+ assertThat("if"(1)(code)).isEqualTo(10 to 12)
+ assertThat("if"(2)(code)).isEqualTo(25 to 27)
+ assertThat("if"(3)(code)).isEqualTo(35 to 37)
+ assertThat("else"(1)(code)).isEqualTo(20 to 24)
+ assertThatCode { "if"(4)(code) }.isInstanceOf(IllegalArgumentException::class.java)
+ assertThatCode { "if"(0)(code) }.isInstanceOf(IllegalArgumentException::class.java)
+ assertThatCode { "else"(2)(code) }.isInstanceOf(IllegalArgumentException::class.java)
+ }
+
+ @Nested
+ inner class specialTestCases {
+ @TestFactory
+ fun `special when conditions`() = flag(
+ """
+ when (0) {
+ 1, 2 -> { println() }
+ in 3..4 -> { println() }
+ is Int -> { println() }
+ 5, 6, in 7..8, is Number -> { println() }
+ else -> { println() }
+ }
+ """.trimIndent(),
+ "->"(1),
+ "->"(2),
+ "->"(3),
+ "->"(4),
+ "->"(5),
+ )
+
+ private fun flag(code: String, vararg locations: (String) -> Pair) =
+ testCombinations(BracePolicy.Never.config, BracePolicy.Always.config, code, *locations)
+
+ @TestFactory
+ fun `extra line break between branches`() =
+ flag(
+ """
+ when (1) {
+
+ 1 -> println()
+
+ 2 -> println()
+
+ else -> println()
+
+ }
+ """.trimIndent(),
+ *NOTHING
+ )
+
+ @TestFactory
+ fun `nested when inside when case block`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ when (2) { 1 -> println(); else -> { println() } }
+ }
+ 2 -> println()
+ else -> { println() }
+ }
+ """.trimIndent(),
+ "->"(3),
+ "->"(4),
+ )
+
+ @TestFactory
+ fun `nested when inside when subject`() =
+ flag(
+ """
+ when (when (2) { 1 -> { 1 }; else -> { 2 } }) {
+ 1 -> {
+ println()
+ }
+ 2 -> println()
+ else -> { println() }
+ }
+ """.trimIndent(),
+ "->"(1),
+ "->"(2),
+ "->"(4),
+ )
+
+ @TestFactory
+ fun `nested when inside when condition`() =
+ flag(
+ """
+ fun f(s: String?) {
+ when {
+ when { s != null -> { s.length }; else -> { 0 } } > 5 -> {
+ true
+ }
+ when(s) { "foo" -> { 1 }; "bar" -> { 1 }; else -> { 2 } } == 1 -> false
+ else -> { null }
+ }
+ }
+ """.trimIndent(),
+ "->"(1),
+ "->"(2),
+ "->"(4),
+ "->"(5),
+ "->"(6),
+ "->"(7),
+ )
+
+ @TestFactory
+ fun `weird curly formatting for multiline whens`() =
+ flag(
+ """
+ when (1) {
+ 1 ->
+ {
+ println()
+ }
+ 2 -> { println()
+ }
+ 3 -> {
+ println() }
+ else ->
+ { println() }
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+ }
+
+ @Nested
+ inner class singleLine {
+
+ @Nested
+ inner class `=always` {
+
+ private fun flag(code: String, vararg locations: (String) -> Pair) =
+ testCombinations(BracePolicy.Always.config, NOT_RELEVANT, code, *locations)
+
+ @TestFactory
+ fun `missing braces are flagged`() =
+ flag(
+ """
+ when (1) {
+ 1 -> println()
+ 2 -> { println(); println() }
+ }
+ """.trimIndent(),
+ "->"(1),
+ )
+
+ @TestFactory
+ fun `existing braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println() }
+ 2 -> { println() }
+ else -> { println(); println() }
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `partially missing braces are flagged (first branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> println()
+ 2 -> { println() }
+ else -> { println(); println() }
+ }
+ """.trimIndent(),
+ "->"(1),
+ )
+
+ @TestFactory
+ fun `partially missing braces are flagged (last branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println() }
+ 2 -> { println(); println() }
+ else -> println()
+ }
+ """.trimIndent(),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially missing braces are flagged (middle branches)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println(); println() }
+ 2 -> println()
+ else -> { println() }
+ }
+ """.trimIndent(),
+ "->"(2),
+ )
+ }
+
+ @Nested
+ inner class `=never` {
+
+ private fun flag(code: String, vararg locations: (String) -> Pair) =
+ testCombinations(BracePolicy.Never.config, NOT_RELEVANT, code, *locations)
+
+ @TestFactory fun `no braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 -> println()
+ 2 -> println()
+ else -> println()
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `existing braces are flagged`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println() }
+ 2 -> { println() }
+ else -> { println(); println() }
+ }
+ """.trimIndent(),
+ "->"(1),
+ "->"(2),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (first branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println(); println() }
+ 2 -> println()
+ else -> println()
+ }
+ """.trimIndent(),
+ "->"(1),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (last branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> println()
+ 2 -> println()
+ else -> { println() }
+ }
+ """.trimIndent(),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (middle branches)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> println()
+ 2 -> { println() }
+ else -> println()
+ }
+ """.trimIndent(),
+ "->"(2),
+ )
+ }
+
+ @Nested
+ inner class `=necessary` {
+
+ private fun flag(code: String, vararg locations: (String) -> Pair) =
+ testCombinations(BracePolicy.Necessary.config, NOT_RELEVANT, code, *locations)
+
+ @TestFactory
+ fun `no braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 -> println()
+ 2 -> println()
+ else -> { println(); println() }
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `existing braces are flagged`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println() }
+ 2 -> { println() }
+ else -> { println() }
+ }
+ """.trimIndent(),
+ "->"(1),
+ "->"(2),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (first branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println() }
+ 2 -> println()
+ else -> { println(); println() }
+ }
+ """.trimIndent(),
+ "->"(1),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (last branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println(); println() }
+ 2 -> println()
+ else -> { println() }
+ }
+ """.trimIndent(),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (middle branches)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> println()
+ 2 -> { println() }
+ else -> { println(); println() }
+ }
+ """.trimIndent(),
+ "->"(2),
+ )
+ }
+
+ @Nested
+ inner class `=consistent` {
+
+ private fun flag(code: String, vararg locations: (String) -> Pair) =
+ testCombinations(BracePolicy.Consistent.config, NOT_RELEVANT, code, *locations)
+
+ @TestFactory
+ fun `no braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 -> println()
+ 2 -> println()
+ else -> println()
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `existing braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println() }
+ 2 -> { println() }
+ else -> { println() }
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `partial braces are flagged (first branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> println()
+ 2 -> { println() }
+ else -> { println() }
+ }
+ """.trimIndent(),
+ "when"(1),
+ )
+
+ @TestFactory
+ fun `partial braces are flagged (last branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println() }
+ 2 -> { println() }
+ else -> println()
+ }
+ """.trimIndent(),
+ "when"(1),
+ )
+
+ @TestFactory
+ fun `partial braces are flagged (middle branches)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> { println() }
+ 2 -> println()
+ else -> { println() }
+ }
+ """.trimIndent(),
+ "when"(1),
+ )
+ }
+ }
+
+ @Nested
+ inner class multiLine {
+
+ @Nested
+ inner class `=always` {
+
+ private fun flag(code: String, vararg locations: (String) -> Pair) =
+ testCombinations(NOT_RELEVANT, BracePolicy.Always.config, code, *locations)
+
+ @TestFactory
+ fun `missing braces are flagged`() =
+ flag(
+ """
+ when (1) {
+ 1 ->
+ println()
+ 2 -> { println() }
+ else -> {
+ println()
+ println()
+ }
+ }
+ """.trimIndent(),
+ "->"(1),
+ )
+
+ @TestFactory
+ fun `existing braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> { println() }
+ else -> {
+ println()
+ println()
+ }
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `partially missing braces are flagged (first branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> println()
+ else -> {
+ println()
+ println()
+ }
+ }
+ """.trimIndent(),
+ "->"(2),
+ )
+
+ @TestFactory
+ fun `partially missing braces are flagged (last branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> {
+ println()
+ println()
+ }
+ else ->
+ println()
+ }
+ """.trimIndent(),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially missing braces are flagged (middle branches)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 ->
+ println()
+ else -> {
+ println()
+ println()
+ }
+ }
+ """.trimIndent(),
+ "->"(2),
+ )
+ }
+
+ @Nested
+ inner class `=never` {
+
+ private fun flag(code: String, vararg locations: (String) -> Pair) =
+ testCombinations(NOT_RELEVANT, BracePolicy.Never.config, code, *locations)
+
+ @TestFactory
+ fun `no braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 ->
+ println()
+ 2 ->
+ println()
+ else ->
+ println()
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `existing braces are flagged`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> {
+ println()
+ }
+ else -> {
+ println()
+ }
+ }
+ """.trimIndent(),
+ "->"(1),
+ "->"(2),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (first branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 ->
+ println()
+
+ else ->
+ println()
+ }
+ """.trimIndent(),
+ "->"(1),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (last branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 ->
+ println()
+ 2 ->
+ println()
+ else -> {
+ println()
+ println()
+ }
+ }
+ """.trimIndent(),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (middle branches)`() =
+ flag(
+ """
+ when (1) {
+ 1 ->
+ println()
+
+ 2 -> {
+ println()
+ println()
+ }
+
+ else ->
+ println()
+ }
+ """.trimIndent(),
+ "->"(2),
+ )
+ }
+
+ @Nested
+ inner class `=necessary` {
+
+ private fun flag(code: String, vararg locations: (String) -> Pair) =
+ testCombinations(NOT_RELEVANT, BracePolicy.Necessary.config, code, *locations)
+
+ @TestFactory
+ fun `no braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 ->
+ println()
+ 2 ->
+ println()
+ else ->
+ println()
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `existing braces are flagged`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> {
+ println()
+ }
+ else -> {
+ println()
+ }
+ }
+ """.trimIndent(),
+ "->"(1),
+ "->"(2),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (first branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> {
+ println()
+ println()
+ }
+ else -> {
+ println()
+ println()
+ }
+ }
+ """.trimIndent(),
+ "->"(1),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (last branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ println()
+ }
+ 2 -> {
+ println()
+ println()
+ }
+ else -> {
+ println()
+ }
+ }
+ """.trimIndent(),
+ "->"(3),
+ )
+
+ @TestFactory
+ fun `partially extra braces are flagged (middle branches)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ println()
+ }
+ 2 -> {
+ println()
+ }
+ else ->
+ println()
+
+ }
+ """.trimIndent(),
+ "->"(2),
+ )
+
+ @TestFactory
+ fun `existing braces are not flagged when necessary`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ println()
+ }
+ 2 -> {
+ println()
+ println()
+ }
+ else -> {
+ println()
+ println()
+ }
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+ }
+
+ @Nested
+ inner class `=consistent` {
+
+ private fun flag(code: String, vararg locations: (String) -> Pair) =
+ testCombinations(NOT_RELEVANT, BracePolicy.Consistent.config, code, *locations)
+
+ @TestFactory
+ fun `no braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 ->
+ println()
+ 2 ->
+ println()
+ else ->
+ println()
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `existing braces are accepted`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> { println() }
+ else -> {
+ println()
+ println()
+ }
+ }
+ """.trimIndent(),
+ *NOTHING,
+ )
+
+ @TestFactory
+ fun `inconsistent braces are flagged`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> {
+ println()
+ }
+ else ->
+ println()
+ }
+ """.trimIndent(),
+ "when"(1),
+ )
+
+ @TestFactory
+ fun `inconsistent braces are flagged (first branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 ->
+ println()
+ else ->
+ println()
+ }
+ """.trimIndent(),
+ "when"(1),
+ )
+
+ @TestFactory
+ fun `inconsistent braces are flagged (last branch)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> {
+ println()
+ }
+ else ->
+ println()
+ }
+ """.trimIndent(),
+ "when"(1),
+ )
+
+ @TestFactory
+ fun `inconsistent braces are flagged (middle branches)`() =
+ flag(
+ """
+ when (1) {
+ 1 -> {
+ println()
+ }
+ 2 -> println()
+ else -> {
+ println()
+ }
+ }
+ """.trimIndent(),
+ "when"(1),
+ )
+ }
+ }
+
+ companion object {
+
+ /**
+ * Not relevant means, that it should be covered, but for that specific test case the value doesn't matter.
+ * This needs to be covered still, to make sure it never becomes relevant.
+ * In this rule, configuration options should be separate for each config, not coupled with each other.
+ */
+ private const val NOT_RELEVANT: String = "*"
+
+ /**
+ * Nothing is expected to be flagged as a finding.
+ */
+ private val NOTHING: Array<(String) -> Pair> = emptyArray()
+
+ private fun createSubject(singleLine: String, multiLine: String): BracesOnWhenStatements {
+ val config = TestConfig(
+ "singleLine" to singleLine,
+ "multiLine" to multiLine
+ )
+ return BracesOnWhenStatements(config)
+ }
+
+ private fun options(option: String): List =
+ if (option == NOT_RELEVANT) {
+ BracePolicy.values().map { it.config }
+ } else {
+ listOf(option)
+ }
+
+ private fun testCombinations(
+ singleLine: String,
+ multiLine: String,
+ code: String,
+ vararg locations: (String) -> Pair
+ ): DynamicNode {
+ val codeLocation = locations.map { it(code) }.toTypedArray()
+ // Separately compile the code because otherwise all the combinations would compile them again and again.
+ val compileTest = dynamicTest("Compiles: $code") {
+ BracesOnWhenStatements().compileAndLint(code)
+ }
+ val validationTests = createBraceTests(singleLine, multiLine) { rule ->
+ rule.test(code, *codeLocation)
+ }
+ val locationString = if (NOTHING.contentEquals(codeLocation)) {
+ "nothing"
+ } else {
+ codeLocation.map { TextLocation(it.first, it.second) }.toString()
+ }
+ return dynamicContainer("flags $locationString in `$code`", validationTests + compileTest)
+ }
+
+ private fun BracesOnWhenStatements.test(code: String, vararg locations: Pair) {
+ // This creates a 10 character prefix (signature/9, space/1) for every code example.
+ // Note: not compileAndLint for performance reasons, compilation is in a separate test.
+ val findings = lint("fun f() { $code }")
+ // Offset text locations by the above prefix, it results in 0-indexed locations.
+ val offset = 10
+ assertThat(findings)
+ .hasTextLocations(
+ *(locations.map { it.first + offset to it.second + offset }).toTypedArray()
+ )
+ }
+
+ /**
+ * Generates a list of tests for the given brace policy combinations.
+ * The expectations in the test will be the same for all combinations.
+ *
+ * @see options for how the arguments are interpreted.
+ */
+ private fun createBraceTests(
+ singleLine: String,
+ multiLine: String,
+ test: (BracesOnWhenStatements) -> Unit
+ ): List {
+ val singleOptions = options(singleLine)
+ val multiOptions = options(multiLine)
+ require(singleOptions.isNotEmpty() && multiOptions.isNotEmpty()) {
+ "No options to test: $singleLine -> $singleOptions, $multiLine -> $multiOptions"
+ }
+ return singleOptions.flatMap { singleLineOption ->
+ multiOptions.map { multiLineOption ->
+ val trace = Exception("Async stack trace of TestFactory")
+ dynamicTest("singleLine=$singleLineOption, multiLine=$multiLineOption") {
+ try {
+ // Note: if you jumped here from a failed test's stack trace,
+ // select the "Async stack trace of TestFactory" cause's
+ // last relevant stack line to jump to the actual test code.
+ test(createSubject(singleLineOption, multiLineOption))
+ } catch (e: Throwable) {
+ generateSequence(e) { it.cause }.last().initCause(trace)
+ throw e
+ }
+ }
+ }
+ }
+ }
+
+ operator fun String.invoke(ordinal: Int): (String) -> Pair =
+ { code ->
+ fun String.next(string: String, start: Int): Int? = indexOf(string, start).takeIf { it != -1 }
+
+ val indices = generateSequence(code.next(this, 0)) { startIndex ->
+ code.next(this, startIndex + 1)
+ }
+ val index = requireNotNull(indices.elementAtOrNull(ordinal - 1)) {
+ "There's no $ordinal. occurrence of '$this' in '$code'"
+ }
+ index to index + this.length
+ }
+ }
+}