Skip to content

Commit

Permalink
RedundantBraces: fix "moving" braces around func
Browse files Browse the repository at this point in the history
Previously, we'd handle the case of `foo(x => { ... })` by:
1. replacing the `(` with a `{`
2. "pretending" to replace the `{` with a `(`
3. upon encountering the `}` for a fake `(`, removing the `(` and
   changing on the owner of the `}`
4. then removing the `)`

Instead, let's modify steps 2, 3 and 4 as follows:
2. remove `{`
3. remove `}`
4. replace the closing `)` with a `}` moving it to the position of the
   `}` in step 3, along with token ownership.
  • Loading branch information
kitbellew committed Mar 16, 2024
1 parent adf6c79 commit 06d3195
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 37 deletions.
Expand Up @@ -49,6 +49,10 @@ case class FormatToken(left: Token, right: Token, meta: FormatToken.Meta) {
/** A format token is uniquely identified by its left token.
*/
override def hashCode(): Int = hash(left).##

private[scalafmt] def withIdx(idx: Int): FormatToken =
copy(meta = meta.copy(idx = idx))

}

object FormatToken {
Expand Down
Expand Up @@ -39,10 +39,11 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(
if (idx >= arr.length) arr.last
else {
val ft = arr(idx)
if (isBefore) {
if (ft.left.start <= tok.start) ft else at(idx - 1)
if (ft.left eq tok) ft
else if (isBefore) {
if (ft.left.start < tok.start) ft else at(idx - 1)
} else {
if (ft.left.start >= tok.start) ft else at(idx + 1)
if (ft.left.start > tok.start) ft else at(idx + 1)
}
}
}
Expand Down
Expand Up @@ -40,6 +40,8 @@ class FormatTokensRewrite(
val tokenMap = Map.newBuilder[TokenOps.TokenHash, Int]
tokenMap.sizeHint(arr.length)

val shiftedIndexMap = mutable.Map.empty[Int, Int]

var appended = 0
var removed = 0
def copySlice(end: Int): Unit = {
Expand All @@ -52,15 +54,16 @@ class FormatTokensRewrite(
val idx = ft.meta.idx
val ftOld = arr(idx)
val rtOld = ftOld.right
@inline def mapOld() = tokenMap += FormatTokens.thash(rtOld) -> appended
@inline def mapOld(dstidx: Int) =
tokenMap += FormatTokens.thash(rtOld) -> dstidx
copySlice(idx)
def append(): Unit = {
if (rtOld ne ft.right) mapOld()
if (rtOld ne ft.right) mapOld(appended)
appended += 1
result += ft
}
def remove(): Unit = {
mapOld()
def remove(dstidx: Int): Unit = {
mapOld(dstidx)
val nextIdx = idx + 1
val nextFt = ftoks.at(nextIdx)
val rtMeta = nextFt.meta
Expand All @@ -69,7 +72,17 @@ class FormatTokensRewrite(
}
removed += 1
}
if (repl.how eq ReplacementType.Remove) remove() else append()
repl.how match {
case ReplacementType.Remove => remove(appended)
case ReplacementType.Replace => append()
case r: ReplacementType.RemoveAndResurrect =>
if (r.idx == idx) { // we moved here
append()
shiftedIndexMap.put(idx, appended)
} else { // we moved from here
remove(shiftedIndexMap.remove(r.idx).getOrElse(appended))
}
}
}

if (appended + removed == 0) ftoks
Expand Down Expand Up @@ -232,20 +245,22 @@ object FormatTokensRewrite {
protected final def replaceToken(
text: String,
owner: Option[Tree] = None,
claim: Iterable[Int] = Nil
claim: Iterable[Int] = Nil,
rtype: ReplacementType = ReplacementType.Replace
)(tok: T)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = {
val mOld = ft.meta.right
val mNew = mOld.copy(text = text, owner = owner.getOrElse(mOld.owner))
val ftNew = ft.copy(right = tok, meta = ft.meta.copy(right = mNew))
Replacement(this, ftNew, ReplacementType.Replace, style, claim)
Replacement(this, ftNew, rtype, style, claim)
}

protected final def replaceTokenBy(
text: String,
owner: Option[Tree] = None,
claim: Iterable[Int] = Nil
claim: Iterable[Int] = Nil,
rtype: ReplacementType = ReplacementType.Replace
)(f: T => T)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement =
replaceToken(text, owner, claim)(f(ft.right))
replaceToken(text, owner, claim, rtype)(f(ft.right))

protected final def replaceTokenIdent(text: String, t: T)(implicit
ft: FormatToken,
Expand Down Expand Up @@ -297,7 +312,24 @@ object FormatTokensRewrite {
private[rewrite] def claim(ftIdx: Int, repl: Replacement): Int = {
val idx = tokens.length
claimed.update(ftIdx, idx)
tokens.append(repl)
tokens.append(
if (repl eq null) null
else
(repl.how match {
case rt: ReplacementType.RemoveAndResurrect =>
claimed.get(rt.idx).flatMap { oldidx =>
val orepl = tokens(oldidx)
val ok = orepl != null && (orepl.rule eq repl.rule) &&
(orepl.how eq ReplacementType.Remove)
if (ok) {
tokens(oldidx) =
repl.copy(ft = repl.ft.withIdx(orepl.ft.meta.idx))
Some(repl.copy(ft = orepl.ft.withIdx(repl.ft.meta.idx)))
} else None
}
case _ => None
}).getOrElse(repl)
)
idx
}

Expand Down Expand Up @@ -343,6 +375,7 @@ object FormatTokensRewrite {
private[rewrite] object ReplacementType {
object Remove extends ReplacementType
object Replace extends ReplacementType
class RemoveAndResurrect(val idx: Int) extends ReplacementType
}

private def mergeWhitespaceLeftToRight(
Expand Down
Expand Up @@ -93,15 +93,6 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
}
}

// we might not keep it but will hint to onRight
private def replaceWithLeftParen(implicit
ft: FormatToken,
style: ScalafmtConfig
): Replacement =
replaceTokenBy("(") { x =>
new Token.LeftParen(x.input, x.dialect, x.start)
}

private def replaceWithEquals(implicit
ft: FormatToken,
style: ScalafmtConfig
Expand All @@ -122,7 +113,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
case b: Term.Block =>
ftoks.getHead(b) match {
case FormatToken(_: Token.LeftBrace, _, lbm) =>
replaceToken(lbm.left.text, Some(rtOwner), lbm.idx - 1 :: Nil) {
replaceToken("{", claim = lbm.idx - 1 :: Nil) {
new Token.LeftBrace(rt.input, rt.dialect, rt.start)
}
case _ => null
Expand All @@ -149,8 +140,27 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
private def onRightParen(left: Replacement)(implicit
ft: FormatToken,
style: ScalafmtConfig
): (Replacement, Replacement) =
(left, removeToken)
): (Replacement, Replacement) = left.how match {
case ReplacementType.Remove => (left, removeToken)
case ReplacementType.Replace if {
val lft = left.ft
val ro = ft.meta.rightOwner
(lft.meta.rightOwner eq ro) &&
lft.right.is[Token.LeftBrace] &&
okToReplaceFunctionInSingleArgApply(ro).isDefined
} =>
val pft = ftoks.prevNonComment(ft)
val rb = pft.left
if (rb.is[Token.RightBrace]) {
// move right to the end of the function
val rType = new ReplacementType.RemoveAndResurrect(pft.meta.idx - 1)
left -> replaceToken("}", rtype = rType) {
// create a different token so that any child tree wouldn't own it
new Token.RightBrace(rb.input, rb.dialect, rb.start)
}
} else null // don't know how to Replace
case _ => null
}

private def onLeftBrace(implicit
ft: FormatToken,
Expand All @@ -175,7 +185,6 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
owner match {
case t: Term.FunctionTerm if t.tokens.last.is[Token.RightBrace] =>
if (!okToRemoveFunctionInApplyOrInit(t)) null
else if (okToReplaceFunctionInSingleArgApply(t)) replaceWithLeftParen
else removeToken
case t: Term.PartialFunction if t.parent.exists { p =>
SingleArgInBraces.orBlock(p).exists(_._2 eq t) &&
Expand All @@ -186,7 +195,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
t.parent match {
case Some(f: Term.FunctionTerm)
if okToReplaceFunctionInSingleArgApply(f) =>
replaceWithLeftParen
removeToken
case Some(_: Term.Interpolate) => handleInterpolation
case _ =>
if (processBlock(t)) removeToken else null
Expand All @@ -210,16 +219,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
ft: FormatToken,
style: ScalafmtConfig
): (Replacement, Replacement) =
left.ft match {
case lft @ FormatToken(_, _: Token.LeftParen, _)
if left.how eq ReplacementType.Replace =>
val right = replaceTokenBy("}", ft.meta.rightOwner.parent) { rt =>
// shifted right
new Token.RightBrace(rt.input, rt.dialect, rt.start + 1)
}
(removeToken(lft, style), right)
case _ => (left, removeToken)
}
(left, removeToken)

private def settings(implicit
style: ScalafmtConfig
Expand Down

0 comments on commit 06d3195

Please sign in to comment.