Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RemoveScala3OptionalBraces: handle fewer braces #3815

Merged
merged 2 commits into from Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/configuration.md
Expand Up @@ -3698,6 +3698,13 @@ The section contains the following settings (available since v3.8.1):
- other flags below might extend rewrites to other cases
- `oldSyntaxToo`
- if `true`, applies also to expressions using deprecated syntax
- `fewerBracesMinSpan` and `fewerBracesMaxSpan`
- will apply the rewrite to last curried single-argument group if
it is enclosed in curly braces (or would be rewritten to curly
braces by the `RedundantBraces` rule)
- will only apply the rewrite if the cumulative span of all visible
(non-whitespace) tokens within the argument is between the two values
- this rule is disabled if `fewerBracesMaxSpan == 0`

Prior to v3.8.1, `rewrite.scala3.removeOptionalBraces` was a flag which
took three possible values (with their equivalent current settings shown):
Expand Down
Expand Up @@ -34,6 +34,8 @@ object RewriteScala3Settings {

case class RemoveOptionalBraces(
enabled: Boolean = true,
fewerBracesMinSpan: Int = 2,
fewerBracesMaxSpan: Int = 0,
oldSyntaxToo: Boolean = false
)

Expand Down
Expand Up @@ -457,6 +457,11 @@ object ScalafmtConfig {
addIf(rewrite.insertBraces.minLines < rewrite.redundantBraces.maxBreaks)
addIf(align.beforeOpenParenDefnSite && !align.closeParenSite)
addIf(align.beforeOpenParenCallSite && !align.closeParenSite)
addIf(rewrite.scala3.removeOptionalBraces.fewerBracesMinSpan <= 0)
if (rewrite.scala3.removeOptionalBraces.fewerBracesMaxSpan != 0) {
addIf(rewrite.scala3.removeOptionalBraces.fewerBracesMaxSpan < 0)
addIf(rewrite.scala3.removeOptionalBraces.fewerBracesMinSpan > rewrite.scala3.removeOptionalBraces.fewerBracesMaxSpan)
}
}
// scalafmt: {}
if (allErrors.isEmpty) Configured.ok(cfg)
Expand Down
Expand Up @@ -88,7 +88,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
): Option[(Replacement, Replacement)] = Option {
ft.right match {
case _: Token.RightBrace => onRightBrace(left)
case _: Token.RightParen => onRightParen(left)
case _: Token.RightParen => onRightParen(left, hasFormatOff)
case _ => null
}
}
Expand Down Expand Up @@ -142,14 +142,28 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
lpFunction.orElse(lpPartialFunction).orNull
}

private def onRightParen(left: Replacement)(implicit
private def onRightParen(left: Replacement, hasFormatOff: Boolean)(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): (Replacement, Replacement) = left.how match {
case ReplacementType.Remove =>
val resOpt = getRightBraceBeforeRightParen(false).map { rb =>
// we'll use right brace later, when applying fewer-braces rewrite
ft.meta.rightOwner match {
case ac: Term.ArgClause =>
ftoks.matchingOpt(rb.left).map(ftoks.justBefore).foreach { lb =>
session.rule[RemoveScala3OptionalBraces].foreach { r =>
session.getClaimed(lb.meta.idx).foreach { case (leftIdx, _) =>
val repl = r.onLeftForArgClause(ac)(lb, left.style)
if (null ne repl) {
implicit val ft: FormatToken = ftoks.prev(rb)
repl.onRightAndClaim(hasFormatOff, leftIdx)
}
}
}
}
case _ =>
}
(left, removeToken)
}
resOpt.orNull
Expand All @@ -170,7 +184,16 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
new Token.RightBrace(rb.input, rb.dialect, rb.start + 1)
}
}
replaceIfAfterRightBrace.orNull // don't know how to Replace
(ft.meta.rightOwner match {
case ac: Term.ArgClause =>
session.rule[RemoveScala3OptionalBraces].flatMap { r =>
val repl = r.onLeftForArgClause(ac)(left.ft, left.style)
if (repl eq null) None else repl.onRight(hasFormatOff)
}
case _ => None
}).getOrElse {
replaceIfAfterRightBrace.orNull // don't know how to Replace
}
case _ => null
}

Expand Down
@@ -1,12 +1,14 @@
package org.scalafmt.rewrite

import scala.reflect.ClassTag

import scala.meta._
import scala.meta.tokens.Token

import org.scalafmt.config.ScalafmtConfig
import org.scalafmt.internal.FormatToken
import org.scalafmt.internal.FormatTokens
import org.scalafmt.util.TreeOps
import org.scalafmt.util.TreeOps._

object RemoveScala3OptionalBraces extends FormatTokensRewrite.RuleFactory {

Expand Down Expand Up @@ -47,6 +49,18 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
if (t.parent.exists(_.is[Defn.Given])) removeToken
else
replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start))
case t: Term.ArgClause => onLeftForArgClause(t)
case t: Term.PartialFunction =>
t.parent match {
case Some(p: Term.ArgClause) if (p.tokens.head match {
case px: Token.LeftBrace => px eq x
case px: Token.LeftParen =>
shouldRewriteArgClauseWithLeftParen[RedundantBraces](px)
case _ => false
}) =>
onLeftForArgClause(p)
case _ => null
}
case _: Term.For if allowOldSyntax || {
val rbFt = ftoks(ftoks.matching(ft.right))
ftoks.nextNonComment(rbFt).right.is[Token.KwDo]
Expand Down Expand Up @@ -82,9 +96,14 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
case _ => false
}
case _ => false
}) || (left.ft.right match {
case _: Token.Colon => !shouldRewriteColonOnRight(left)
case _ => false
})
ft.right match {
case _ if notOkToRewrite => None
case _: Token.RightParen if RewriteTrailingCommas.checkIfPrevious =>
Some((left, removeToken))
case x: Token.RightBrace =>
val replacement = ft.meta.rightOwner match {
case _: Term.For if allowOldSyntax && !nextFt.right.is[Token.KwDo] =>
Expand All @@ -96,16 +115,18 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
}
}

private def onLeftForBlock(
tree: Term.Block
)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement =
private def onLeftForBlock(tree: Term.Block)(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): Replacement =
tree.parent.fold(null: Replacement) {
case t: Term.If =>
val ok = ftoks.prevNonComment(ft).left match {
case _: Token.KwIf => true
case _: Token.KwThen => true
case _: Token.KwElse =>
!TreeOps.isTreeMultiStatBlock(t.elsep) ||
!isTreeMultiStatBlock(t.elsep) ||
ftoks.tokenAfter(t.cond).right.is[Token.KwThen]
case _: Token.RightParen => allowOldSyntax
case _ => false
Expand Down Expand Up @@ -136,7 +157,94 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens)
else if (ftoks.prevNonComment(ft).left.is[Token.Equals]) removeToken
else null
case p: Tree.WithBody => if (p.body eq tree) removeToken else null
case p: Term.ArgClause =>
p.tokens.head match {
case _: Token.LeftBrace =>
onLeftForArgClause(p)
case px: Token.LeftParen
if shouldRewriteArgClauseWithLeftParen[RedundantParens](px) =>
onLeftForArgClause(p)
case _ => null
}
case _ => null
}

private def shouldRewriteArgClauseWithLeftParen[A <: Rule](
lp: Token
)(implicit ft: FormatToken, session: Session, tag: ClassTag[A]) = {
val prevFt = ftoks.prevNonComment(ft)
prevFt.left.eq(lp) && session
.claimedRule(prevFt.meta.idx - 1)
.exists(x => tag.runtimeClass.isInstance(x.rule))
}

private[rewrite] def onLeftForArgClause(
tree: Term.ArgClause
)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = {
val ok = style.dialect.allowFewerBraces &&
style.rewrite.scala3.removeOptionalBraces.fewerBracesMaxSpan > 0 &&
isSeqSingle(tree.values)
if (!ok) return null

tree.parent match {
case Some(p: Term.Apply) if (p.parent match {
case Some(pp: Term.Apply) => pp.fun ne p
case _ => true
}) =>
val x = ft.right // `{` or `(`
replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start))
case _ => null
}
}

private def shouldRewriteColonOnRight(left: Replacement)(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): Boolean = {
val lft = left.ft
lft.meta.rightOwner match {
case t: Term.ArgClause => shouldRewriteArgClauseColonOnRight(t, lft)
case t @ (_: Term.Block | _: Term.PartialFunction) =>
t.parent match {
case Some(p: Term.ArgClause) =>
shouldRewriteArgClauseColonOnRight(p, lft)
case _ => false
}
case _ => true // template etc
}
}

private def shouldRewriteArgClauseColonOnRight(
ac: Term.ArgClause,
lft: FormatToken
)(implicit
ft: FormatToken,
session: Session,
style: ScalafmtConfig
): Boolean = ac.values match {
case arg :: Nil =>
val begIdx = math.max(ftoks.getHead(arg).meta.idx - 1, lft.meta.idx + 1)
val endIdx = math.min(ftoks.getLast(arg).meta.idx, ft.meta.idx)
var span = 0
val rob = style.rewrite.scala3.removeOptionalBraces
val maxStats = rob.fewerBracesMaxSpan
(begIdx until endIdx).foreach { idx =>
val tokOpt = session.claimedRule(idx) match {
case Some(x) if x.ft.meta.idx == idx =>
if (x.how == ReplacementType.Remove) None
else Some(x.ft.right)
case _ =>
val tok = ftoks(idx).right
if (tok.is[Token.Whitespace]) None else Some(tok)
}
tokOpt.foreach { tok =>
span += tok.end - tok.start
if (span > maxStats) return false // RETURNING!!!
}
}
span >= rob.fewerBracesMinSpan
case _ => false
}

}