Skip to content

Commit

Permalink
RemoveScala3OptionalBraces: handle fewer braces
Browse files Browse the repository at this point in the history
In some cases, invoke this rule from redundant-braces, since that rule
is guaranteed to run before remove-optional-braces.
  • Loading branch information
kitbellew authored and albertikm committed Mar 10, 2024
1 parent b522e67 commit fc0c020
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 131 deletions.
Expand Up @@ -106,6 +106,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens)

private def onLeftParen(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): Replacement = {
val rt = ft.right
Expand Down Expand Up @@ -137,7 +138,15 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
case _ => None
}

lpFunction.orElse(lpPartialFunction).orNull
val repl = lpFunction.orElse(lpPartialFunction).orNull
(rtOwner match {
case ac: Term.ArgClause
if repl != null && repl.ft.right.is[Token.LeftBrace] =>
session.rule[RemoveScala3OptionalBraces].flatMap { r =>
Option(r.onLeftForArgClause(ac, None))
}
case _ => None
}).getOrElse(repl)
}

private def onRightParen(
Expand Down Expand Up @@ -168,7 +177,8 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
owner match {
case t: Term.FunctionTerm if t.tokens.last.is[Token.RightBrace] =>
if (!okToRemoveFunctionInApplyOrInit(t)) null
else if (okToReplaceFunctionInSingleArgApply(t)) replaceWithLeftParen
else if (okToReplaceFunctionInSingleArgApply(t))
handleFuncInSingleArgApply(t)
else removeToken
case t: Term.PartialFunction if t.parent.exists { p =>
SingleArgInBraces.orBlock(p).contains(t) &&
Expand All @@ -179,7 +189,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
t.parent match {
case Some(f: Term.FunctionTerm)
if okToReplaceFunctionInSingleArgApply(f) =>
replaceWithLeftParen
handleFuncInSingleArgApply(f)
case Some(_: Term.Interpolate) => handleInterpolation
case _ =>
if (processBlock(t)) removeToken else null
Expand Down Expand Up @@ -553,4 +563,20 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
ftoks(t.name.tokens.head, -1).left.is[Token.Dot]
}

private def handleFuncInSingleArgApply(
f: Term.FunctionTerm
)(implicit ft: FormatToken, session: Session): Replacement =
f.parent match {
case Some(ac: Term.ArgClause) if {
val acFt = ftoks.tokenJustBefore(ac)
acFt.right.is[Token.LeftParen] &&
session.claimedRule(acFt).exists { x =>
x.ft.right.is[Token.Colon] &&
x.how == ReplacementType.Replace
}
} =>
removeToken
case _ => replaceWithLeftParen
}

}
@@ -1,5 +1,9 @@
package org.scalafmt.rewrite

import scala.annotation.tailrec
import scala.collection.mutable
import scala.reflect.ClassTag

import scala.meta._
import scala.meta.tokens.Token

Expand Down Expand Up @@ -47,6 +51,18 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
if (t.parent.exists(_.is[Defn.Given])) removeToken
else
replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start))
case t: Term.ArgClause => onLeftForArgClause(t, Some(false))
case t: Term.PartialFunction =>
t.parent match {
case Some(p: Term.ArgClause) if (p.tokens.head match {
case px: Token.LeftBrace => px eq x
case px: Token.LeftParen =>
shouldRewriteArgClauseWithLeftParen[RedundantBraces](px)
case _ => false
}) =>
onLeftForArgClause(p, sharesBracesWithArg = Some(true))
case _ => null
}
case _: Term.For if allowOldSyntax || {
val rbFt = ftoks(ftoks.matching(ft.right))
ftoks.nextNonComment(rbFt).right.is[Token.KwDo]
Expand Down Expand Up @@ -83,6 +99,8 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
})
ft.right match {
case _ if notOkToRewrite => None
case _: Token.RightParen =>
Some((left, removeToken))
case x: Token.RightBrace =>
val replacement = ft.meta.rightOwner match {
case _: Term.For if allowOldSyntax && !nextFt.right.is[Token.KwDo] =>
Expand All @@ -94,9 +112,11 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
}
}

private def onLeftForBlock(
tree: Term.Block
)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement =
private def onLeftForBlock(tree: Term.Block)(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): Replacement =
tree.parent.fold(null: Replacement) {
case t: Term.If =>
val ok = ftoks.prevNonComment(ft).left match {
Expand Down Expand Up @@ -134,7 +154,73 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
else if (ftoks.prevNonComment(ft).left.is[Token.Equals]) removeToken
else null
case p: Tree.WithBody => if (p.body eq tree) removeToken else null
case p: Term.ArgClause =>
p.tokens.head match {
case px: Token.LeftBrace =>
onLeftForArgClause(p, Some(px eq ft.right))
case px: Token.LeftParen
if shouldRewriteArgClauseWithLeftParen[RedundantParens](px) =>
onLeftForArgClause(p, Some(true))
case _ => null
}
case _ => null
}

private def shouldRewriteArgClauseWithLeftParen[A <: Rule](
lp: Token
)(implicit ft: FormatToken, session: Session, tag: ClassTag[A]) = {
val prevFt = ftoks.prevNonComment(ft)
prevFt.left.eq(lp) && session
.claimedRule(prevFt.meta.idx - 1)
.exists(x => tag.runtimeClass.isInstance(x.rule))
}

def onLeftForArgClause(
tree: Term.ArgClause,
sharesBracesWithArg: Option[Boolean]
)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = {
val rob = style.rewrite.scala3.removeOptionalBraces
val maxStats = rob.fewerBracesMaxStats
if (maxStats == 0) return null
@tailrec
def checkCountInRange(queue: mutable.Queue[Tree], cnt: Int): Boolean =
if (cnt > maxStats) false
else if (queue.isEmpty) cnt >= rob.fewerBracesMinStats
else {
val next = queue.dequeue()
def enqueue(trees: Iterable[Tree]*): Int = {
val len = queue.length
trees.foreach(queue ++= _)
queue.length - len
}
val delta = next match {
case x: Term.Block => enqueue(x.stats)
case x: Tree.WithBody => enqueue(x.body :: Nil) - 1
case x: Tree.WithCases => enqueue(x.cases)
case x: Tree.WithEnums => enqueue(x.enums)
case x: Stat.WithTemplate => enqueue(x.templ.early, x.templ.stats)
case _ => 0
}
checkCountInRange(queue, cnt + delta)
}
def ok = tree.values match {
case arg :: Nil =>
val queue = mutable.Queue.empty[Tree]
queue += arg
val ignoreArg = sharesBracesWithArg
.getOrElse(arg.tokens.headOption.contains(tree.tokens.head))
checkCountInRange(queue, if (ignoreArg) 0 else 1)
case _ => false
}
tree.parent match {
case Some(t: Term.Apply) if (t.parent match {
case Some(pp: Term.Apply) if pp.fun eq t => false
case _ => style.dialect.allowFewerBraces && ok
}) =>
val x = ft.right // `{` or `(`
replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start))
case _ => null
}
}

}
50 changes: 21 additions & 29 deletions scalafmt-tests/src/test/resources/scala3/FewerBraces.stat
Expand Up @@ -1827,10 +1827,9 @@ foo
.mtd1 {
x + 1
}
.mtd2 {
x + 1
x + 2
}
.mtd2:
x + 1
x + 2
.mtd3 {
x + 1
x + 2
Expand Down Expand Up @@ -1861,10 +1860,9 @@ foo
.mtd1 { x =>
x + 1
}
.mtd2 { x =>
.mtd2: x =>
x + 1
x + 2
}
.mtd3 { x =>
x + 1
x + 2
Expand Down Expand Up @@ -1902,10 +1900,9 @@ foo
.mtd1 { x =>
x + 1
}
.mtd2 { x =>
.mtd2: x =>
x + 1
x + 2
}
.mtd3 { x =>
x + 1
x + 2
Expand Down Expand Up @@ -1943,10 +1940,9 @@ foo
.mtd1 { x =>
x + 1
}
.mtd2 { x =>
.mtd2: x =>
x + 1
x + 2
}
.mtd3 { x =>
x + 1
x + 2
Expand Down Expand Up @@ -1977,10 +1973,9 @@ foo
.mtd1 { case x =>
x + 1
}
.mtd2 {
case x => x + 1
case y => y + 1
}
.mtd2:
case x => x + 1
case y => y + 1
.mtd3 {
case x => x + 1
case y => y + 1
Expand Down Expand Up @@ -2018,10 +2013,9 @@ foo
.mtd1 { case x =>
x + 1
}
.mtd2 {
case x => x + 1
case y => y + 1
}
.mtd2:
case x => x + 1
case y => y + 1
.mtd3 {
case x => x + 1
case y => y + 1
Expand Down Expand Up @@ -2059,11 +2053,10 @@ foo
bar match
case x => x + 1
}
.mtd2 {
bar match
case x => x + 1
case y => y + 1
}
.mtd2:
bar match
case x => x + 1
case y => y + 1
.mtd3 {
bar match
case x => x + 1
Expand Down Expand Up @@ -2106,12 +2099,11 @@ foo
def x =
x + 3
}
.mtd2 {
x + 1
def x =
x + 3
x + 4
}
.mtd2:
x + 1
def x =
x + 3
x + 4
.mtd3 {
x + 1
def x =
Expand Down

0 comments on commit fc0c020

Please sign in to comment.