Skip to content

Commit

Permalink
RemoveScala3OptionalBraces: handle fewer braces
Browse files Browse the repository at this point in the history
In some cases, invoke this rule from redundant-braces, since that rule
is guaranteed to run before remove-optional-braces.
  • Loading branch information
kitbellew committed Mar 28, 2024
1 parent 402432c commit 2b2a6d5
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 200 deletions.
Expand Up @@ -88,7 +88,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
): Option[(Replacement, Replacement)] = Option {
ft.right match {
case _: Token.RightBrace => onRightBrace(left)
case _: Token.RightParen => onRightParen(left)
case _: Token.RightParen => onRightParen(left, hasFormatOff)
case _ => null
}
}
Expand Down Expand Up @@ -142,14 +142,28 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
lpFunction.orElse(lpPartialFunction).orNull
}

private def onRightParen(left: Replacement)(implicit
private def onRightParen(left: Replacement, hasFormatOff: Boolean)(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): (Replacement, Replacement) = left.how match {
case ReplacementType.Remove =>
val resOpt = getRightBraceBeforeRightParen(false).map { rb =>
// we'll use right brace later, when applying fewer-braces rewrite
ft.meta.rightOwner match {
case ac: Term.ArgClause =>
ftoks.matchingOpt(rb.left).map(ftoks.justBefore).foreach { lb =>
session.rule[RemoveScala3OptionalBraces].foreach { r =>
session.getClaimed(lb.meta.idx).foreach { case (leftIdx, _) =>
val repl = r.onLeftForArgClause(ac)(lb, left.style)
if (null ne repl) {
implicit val ft: FormatToken = ftoks.prev(rb)
repl.onRightAndClaim(hasFormatOff, leftIdx)
}
}
}
}
case _ =>
}
(left, removeToken)
}
resOpt.orNull
Expand All @@ -170,7 +184,16 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
new Token.RightBrace(rb.input, rb.dialect, rb.start + 1)
}
}
replaceIfAfterRightBrace.orNull // don't know how to Replace
(ft.meta.rightOwner match {
case ac: Term.ArgClause =>
session.rule[RemoveScala3OptionalBraces].flatMap { r =>
val repl = r.onLeftForArgClause(ac)(left.ft, left.style)
if (repl eq null) None else repl.onRight(hasFormatOff)
}
case _ => None
}).getOrElse {
replaceIfAfterRightBrace.orNull // don't know how to Replace
}
case _ => null
}

Expand Down
@@ -1,5 +1,7 @@
package org.scalafmt.rewrite

import scala.reflect.ClassTag

import scala.meta._
import scala.meta.tokens.Token

Expand Down Expand Up @@ -47,6 +49,18 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
if (t.parent.exists(_.is[Defn.Given])) removeToken
else
replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start))
case t: Term.ArgClause => onLeftForArgClause(t)
case t: Term.PartialFunction =>
t.parent match {
case Some(p: Term.ArgClause) if (p.tokens.head match {
case px: Token.LeftBrace => px eq x
case px: Token.LeftParen =>
shouldRewriteArgClauseWithLeftParen[RedundantBraces](px)
case _ => false
}) =>
onLeftForArgClause(p)
case _ => null
}
case _: Term.For if allowOldSyntax || {
val rbFt = ftoks(ftoks.matching(ft.right))
ftoks.nextNonComment(rbFt).right.is[Token.KwDo]
Expand Down Expand Up @@ -82,9 +96,14 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
case _ => false
}
case _ => false
}) || (left.ft.right match {
case _: Token.Colon => !shouldRewriteColonOnRight(left)
case _ => false
})
ft.right match {
case _ if notOkToRewrite => None
case _: Token.RightParen if RewriteTrailingCommas.checkIfPrevious =>
Some((left, removeToken))
case x: Token.RightBrace =>
val replacement = ft.meta.rightOwner match {
case _: Term.For if allowOldSyntax && !nextFt.right.is[Token.KwDo] =>
Expand All @@ -96,9 +115,11 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
}
}

private def onLeftForBlock(
tree: Term.Block
)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement =
private def onLeftForBlock(tree: Term.Block)(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): Replacement =
tree.parent.fold(null: Replacement) {
case t: Term.If =>
val ok = ftoks.prevNonComment(ft).left match {
Expand Down Expand Up @@ -136,7 +157,94 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
else if (ftoks.prevNonComment(ft).left.is[Token.Equals]) removeToken
else null
case p: Tree.WithBody => if (p.body eq tree) removeToken else null
case p: Term.ArgClause =>
p.tokens.head match {
case _: Token.LeftBrace =>
onLeftForArgClause(p)
case px: Token.LeftParen
if shouldRewriteArgClauseWithLeftParen[RedundantParens](px) =>
onLeftForArgClause(p)
case _ => null
}
case _ => null
}

private def shouldRewriteArgClauseWithLeftParen[A <: Rule](
lp: Token
)(implicit ft: FormatToken, session: Session, tag: ClassTag[A]) = {
val prevFt = ftoks.prevNonComment(ft)
prevFt.left.eq(lp) && session
.claimedRule(prevFt.meta.idx - 1)
.exists(x => tag.runtimeClass.isInstance(x.rule))
}

private[rewrite] def onLeftForArgClause(
tree: Term.ArgClause
)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = {
val ok = style.dialect.allowFewerBraces &&
style.rewrite.scala3.removeOptionalBraces.fewerBracesMaxSpan > 0 &&
isSeqSingle(tree.values)
if (!ok) return null

tree.parent match {
case Some(p: Term.Apply) if (p.parent match {
case Some(pp: Term.Apply) => pp.fun ne p
case _ => true
}) =>
val x = ft.right // `{` or `(`
replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start))
case _ => null
}
}

private def shouldRewriteColonOnRight(left: Replacement)(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): Boolean = {
val lft = left.ft
lft.meta.rightOwner match {
case t: Term.ArgClause => shouldRewriteArgClauseColonOnRight(t, lft)
case t @ (_: Term.Block | _: Term.PartialFunction) =>
t.parent match {
case Some(p: Term.ArgClause) =>
shouldRewriteArgClauseColonOnRight(p, lft)
case _ => false
}
case _ => true // template etc
}
}

private def shouldRewriteArgClauseColonOnRight(
ac: Term.ArgClause,
lft: FormatToken
)(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): Boolean = ac.values match {
case arg :: Nil =>
val begIdx = math.max(ftoks.getHead(arg).meta.idx - 1, lft.meta.idx + 1)
val endIdx = math.min(ftoks.getLast(arg).meta.idx, ft.meta.idx)
var span = 0
val rob = style.rewrite.scala3.removeOptionalBraces
val maxStats = rob.fewerBracesMaxSpan
(begIdx until endIdx).foreach { idx =>
val tokOpt = session.claimedRule(idx) match {
case Some(x) if x.ft.meta.idx == idx =>
if (x.how == ReplacementType.Remove) None
else Some(x.ft.right)
case _ =>
val tok = ftoks(idx).right
if (tok.is[Token.Whitespace]) None else Some(tok)
}
tokOpt.foreach { tok =>
span += tok.end - tok.start
if (span > maxStats) return false // RETURNING!!!
}
}
span >= rob.fewerBracesMinSpan
case _ => false
}

}
90 changes: 39 additions & 51 deletions scalafmt-tests/src/test/resources/scala3/FewerBraces.stat
Expand Up @@ -1827,10 +1827,9 @@ foo
.mtd1 {
x + 1
}
.mtd2 {
x + 1
x + 2
}
.mtd2:
x + 1
x + 2
.mtd3 {
x + 1
x + 2
Expand Down Expand Up @@ -1861,10 +1860,9 @@ foo
.mtd1 { x =>
x + 1
}
.mtd2 { x =>
.mtd2: x =>
x + 1
x + 2
}
.mtd3 { x =>
x + 1
x + 2
Expand Down Expand Up @@ -1902,10 +1900,9 @@ foo
.mtd1 { x =>
x + 1
}
.mtd2 { x =>
.mtd2: x =>
x + 1
x + 2
}
.mtd3 { x =>
x + 1
x + 2
Expand Down Expand Up @@ -1945,10 +1942,9 @@ foo
.mtd1 { x =>
x + 1
}
.mtd2 { x =>
.mtd2: x =>
x + 1
x + 2
}
.mtd3 { x =>
x + 1
x + 2
Expand Down Expand Up @@ -1988,10 +1984,9 @@ foo
.mtd1 { x =>
x + 1
}
.mtd2 { x =>
.mtd2: x =>
x + 1
x + 2
}
.mtd3 { x =>
x + 1
x + 2
Expand Down Expand Up @@ -2028,15 +2023,14 @@ object a:
mtd1 { x =>
x + 1
}
+ mtd2 { x =>
x + 1
x + 2
}
+ mtd3 { x =>
x + 1
x + 2
x + 3
}
+ mtd2: x =>
x + 1
x + 2
+ mtd3 { x =>
x + 1
x + 2
x + 3
}
<<< rewrite to fewer braces: func in parens and braces
rewrite.rules = [RedundantBraces, RedundantParens]
rewrite.scala3.removeOptionalBraces = {
Expand Down Expand Up @@ -2069,10 +2063,9 @@ foo
.mtd1 { x =>
x + 1
}
.mtd2 { x =>
.mtd2: x =>
x + 1
x + 2
}
.mtd3 { x =>
x + 1
x + 2
Expand Down Expand Up @@ -2109,15 +2102,14 @@ object a:
mtd1 { x =>
x + 1
}
+ mtd2 { x =>
x + 1
x + 2
}
+ mtd3 { x =>
x + 1
x + 2
x + 3
}
+ mtd2: x =>
x + 1
x + 2
+ mtd3 { x =>
x + 1
x + 2
x + 3
}
<<< rewrite to fewer braces: partial func
rewrite.scala3.removeOptionalBraces = {
enabled = yes
Expand All @@ -2143,10 +2135,9 @@ foo
.mtd1 { case x =>
x + 1
}
.mtd2 {
case x => x + 1
case y => y + 1
}
.mtd2:
case x => x + 1
case y => y + 1
.mtd3 {
case x => x + 1
case y => y + 1
Expand Down Expand Up @@ -2184,10 +2175,9 @@ foo
.mtd1 { case x =>
x + 1
}
.mtd2 {
case x => x + 1
case y => y + 1
}
.mtd2:
case x => x + 1
case y => y + 1
.mtd3 {
case x => x + 1
case y => y + 1
Expand Down Expand Up @@ -2225,11 +2215,10 @@ foo
bar match
case x => x + 1
}
.mtd2 {
bar match
case x => x + 1
case y => y + 1
}
.mtd2:
bar match
case x => x + 1
case y => y + 1
.mtd3 {
bar match
case x => x + 1
Expand Down Expand Up @@ -2272,12 +2261,11 @@ foo
def x =
x + 3
}
.mtd2 {
x + 1
def x =
x + 3
x + 4
}
.mtd2:
x + 1
def x =
x + 3
x + 4
.mtd3 {
x + 1
def x =
Expand Down

0 comments on commit 2b2a6d5

Please sign in to comment.