diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatToken.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatToken.scala index b8cf3bc541..102dfc5253 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatToken.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatToken.scala @@ -49,6 +49,10 @@ case class FormatToken(left: Token, right: Token, meta: FormatToken.Meta) { /** A format token is uniquely identified by its left token. */ override def hashCode(): Int = hash(left).## + + private[scalafmt] def withIdx(idx: Int): FormatToken = + copy(meta = meta.copy(idx = idx)) + } object FormatToken { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala index 22c9bad16a..4be3a59a9a 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala @@ -39,10 +39,11 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])( if (idx >= arr.length) arr.last else { val ft = arr(idx) - if (isBefore) { - if (ft.left.start <= tok.start) ft else at(idx - 1) + if (ft.left eq tok) ft + else if (isBefore) { + if (ft.left.start < tok.start) ft else at(idx - 1) } else { - if (ft.left.start >= tok.start) ft else at(idx + 1) + if (ft.left.start > tok.start) ft else at(idx + 1) } } } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala index a54f443ef1..c6ae97a789 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala @@ -40,6 +40,8 @@ class FormatTokensRewrite( val tokenMap = Map.newBuilder[TokenOps.TokenHash, Int] tokenMap.sizeHint(arr.length) + val shiftedIndexMap = mutable.Map.empty[Int, Int] + var appended = 0 var removed = 0 def copySlice(end: Int): Unit = { @@ -52,15 +54,16 @@ class FormatTokensRewrite( val idx = ft.meta.idx val ftOld = arr(idx) val rtOld = ftOld.right - @inline def mapOld() = tokenMap += FormatTokens.thash(rtOld) -> appended + @inline def mapOld(dstidx: Int) = + tokenMap += FormatTokens.thash(rtOld) -> dstidx copySlice(idx) def append(): Unit = { - if (rtOld ne ft.right) mapOld() + if (rtOld ne ft.right) mapOld(appended) appended += 1 result += ft } - def remove(): Unit = { - mapOld() + def remove(dstidx: Int): Unit = { + mapOld(dstidx) val nextIdx = idx + 1 val nextFt = ftoks.at(nextIdx) val rtMeta = nextFt.meta @@ -69,7 +72,17 @@ class FormatTokensRewrite( } removed += 1 } - if (repl.how eq ReplacementType.Remove) remove() else append() + repl.how match { + case ReplacementType.Remove => remove(appended) + case ReplacementType.Replace => append() + case r: ReplacementType.RemoveAndResurrect => + if (r.idx == idx) { // we moved here + append() + shiftedIndexMap.put(idx, appended) + } else { // we moved from here + remove(shiftedIndexMap.remove(r.idx).getOrElse(appended)) + } + } } if (appended + removed == 0) ftoks @@ -232,20 +245,22 @@ object FormatTokensRewrite { protected final def replaceToken( text: String, owner: Option[Tree] = None, - claim: Iterable[Int] = Nil + claim: Iterable[Int] = Nil, + rtype: ReplacementType = ReplacementType.Replace )(tok: T)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = { val mOld = ft.meta.right val mNew = mOld.copy(text = text, owner = owner.getOrElse(mOld.owner)) val ftNew = ft.copy(right = tok, meta = ft.meta.copy(right = mNew)) - Replacement(this, ftNew, ReplacementType.Replace, style, claim) + Replacement(this, ftNew, rtype, style, claim) } protected final def replaceTokenBy( text: String, owner: Option[Tree] = None, - claim: Iterable[Int] = Nil + claim: Iterable[Int] = Nil, + rtype: ReplacementType = ReplacementType.Replace )(f: T => T)(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = - replaceToken(text, owner, claim)(f(ft.right)) + replaceToken(text, owner, claim, rtype)(f(ft.right)) protected final def replaceTokenIdent(text: String, t: T)(implicit ft: FormatToken, @@ -297,7 +312,24 @@ object FormatTokensRewrite { private[rewrite] def claim(ftIdx: Int, repl: Replacement): Int = { val idx = tokens.length claimed.update(ftIdx, idx) - tokens.append(repl) + tokens.append( + if (repl eq null) null + else + (repl.how match { + case rt: ReplacementType.RemoveAndResurrect => + claimed.get(rt.idx).flatMap { oldidx => + val orepl = tokens(oldidx) + val ok = orepl != null && (orepl.rule eq repl.rule) && + (orepl.how eq ReplacementType.Remove) + if (ok) { + tokens(oldidx) = + repl.copy(ft = repl.ft.withIdx(orepl.ft.meta.idx)) + Some(repl.copy(ft = orepl.ft.withIdx(repl.ft.meta.idx))) + } else None + } + case _ => None + }).getOrElse(repl) + ) idx } @@ -343,6 +375,7 @@ object FormatTokensRewrite { private[rewrite] object ReplacementType { object Remove extends ReplacementType object Replace extends ReplacementType + class RemoveAndResurrect(val idx: Int) extends ReplacementType } private def mergeWhitespaceLeftToRight( diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index 1894009cd7..bbbcc72f77 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -93,15 +93,6 @@ class RedundantBraces(implicit val ftoks: FormatTokens) } } - // we might not keep it but will hint to onRight - private def replaceWithLeftParen(implicit - ft: FormatToken, - style: ScalafmtConfig - ): Replacement = - replaceTokenBy("(") { x => - new Token.LeftParen(x.input, x.dialect, x.start) - } - private def replaceWithEquals(implicit ft: FormatToken, style: ScalafmtConfig @@ -122,7 +113,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens) case b: Term.Block => ftoks.getHead(b) match { case FormatToken(_: Token.LeftBrace, _, lbm) => - replaceToken(lbm.left.text, Some(rtOwner), lbm.idx - 1 :: Nil) { + replaceToken("{", claim = lbm.idx - 1 :: Nil) { new Token.LeftBrace(rt.input, rt.dialect, rt.start) } case _ => null @@ -149,8 +140,27 @@ class RedundantBraces(implicit val ftoks: FormatTokens) private def onRightParen(left: Replacement)(implicit ft: FormatToken, style: ScalafmtConfig - ): (Replacement, Replacement) = - (left, removeToken) + ): (Replacement, Replacement) = left.how match { + case ReplacementType.Remove => (left, removeToken) + case ReplacementType.Replace if { + val lft = left.ft + val ro = ft.meta.rightOwner + (lft.meta.rightOwner eq ro) && + lft.right.is[Token.LeftBrace] && + okToReplaceFunctionInSingleArgApply(ro).isDefined + } => + val pft = ftoks.prevNonComment(ft) + val rb = pft.left + if (rb.is[Token.RightBrace]) { + // move right to the end of the function + val rType = new ReplacementType.RemoveAndResurrect(pft.meta.idx - 1) + left -> replaceToken("}", rtype = rType) { + // create a different token so that any child tree wouldn't own it + new Token.RightBrace(rb.input, rb.dialect, rb.start) + } + } else null // don't know how to Replace + case _ => null + } private def onLeftBrace(implicit ft: FormatToken, @@ -175,7 +185,6 @@ class RedundantBraces(implicit val ftoks: FormatTokens) owner match { case t: Term.FunctionTerm if t.tokens.last.is[Token.RightBrace] => if (!okToRemoveFunctionInApplyOrInit(t)) null - else if (okToReplaceFunctionInSingleArgApply(t)) replaceWithLeftParen else removeToken case t: Term.PartialFunction if t.parent.exists { p => SingleArgInBraces.orBlock(p).exists(_._2 eq t) && @@ -186,7 +195,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens) t.parent match { case Some(f: Term.FunctionTerm) if okToReplaceFunctionInSingleArgApply(f) => - replaceWithLeftParen + removeToken case Some(_: Term.Interpolate) => handleInterpolation case _ => if (processBlock(t)) removeToken else null @@ -210,16 +219,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens) ft: FormatToken, style: ScalafmtConfig ): (Replacement, Replacement) = - left.ft match { - case lft @ FormatToken(_, _: Token.LeftParen, _) - if left.how eq ReplacementType.Replace => - val right = replaceTokenBy("}", ft.meta.rightOwner.parent) { rt => - // shifted right - new Token.RightBrace(rt.input, rt.dialect, rt.start + 1) - } - (removeToken(lft, style), right) - case _ => (left, removeToken) - } + (left, removeToken) private def settings(implicit style: ScalafmtConfig