Skip to content

Commit

Permalink
FormatTokensRewrite: move Rule into Replacement
Browse files Browse the repository at this point in the history
That way, within one rule we'll be able to invoke a different one, and
the replacement value will indicate which rule actually processed the
substitution.
  • Loading branch information
kitbellew committed Mar 10, 2024
1 parent 9f0e6a5 commit 8c00f0e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 60 deletions.
Expand Up @@ -133,7 +133,7 @@ class FormatTokensRewrite(
if (formatOff) None
else
session.claimedRule match {
case Some(c) => if (applyRule(c.rule)) Some(c.rule) else None
case Some(c) => applyRule(c.rule)
case _ => applyRules
}
leftDelimIndex.prepend((ldelimIdx, ruleOpt))
Expand Down Expand Up @@ -189,7 +189,7 @@ class FormatTokensRewrite(
private def applyRule(rule: Rule)(implicit
ft: FormatToken,
session: Session
): Boolean = {
): Option[Rule] = {
implicit val style = styleMap.at(ft.right)
session.applyRule(rule)
}
Expand Down Expand Up @@ -219,6 +219,34 @@ object FormatTokensRewrite {
session: Session,
style: ScalafmtConfig
): Option[(Replacement, Replacement)]

protected final def removeToken(implicit ft: FormatToken): Replacement =
new Replacement(this, ft, ReplacementType.Remove)

protected final def replaceToken(
text: String,
owner: Option[Tree] = None,
claim: Iterable[Int] = Nil
)(tok: T)(implicit ft: FormatToken): Replacement = {
val mOld = ft.meta.right
val mNew = mOld.copy(text = text, owner = owner.getOrElse(mOld.owner))
val ftNew = ft.copy(right = tok, meta = ft.meta.copy(right = mNew))
new Replacement(this, ftNew, ReplacementType.Replace, claim)
}

protected final def replaceTokenBy(
text: String,
owner: Option[Tree] = None,
claim: Iterable[Int] = Nil
)(f: T => T)(implicit ft: FormatToken): Replacement =
replaceToken(text, owner, claim)(f(ft.right))

protected final def replaceTokenIdent(text: String, t: T)(implicit
ft: FormatToken
): Replacement = replaceToken(text)(
new T.Ident(t.input, t.dialect, t.start, t.start + text.length, text)
)

}

private[rewrite] trait RuleFactory {
Expand All @@ -245,37 +273,39 @@ object FormatTokensRewrite {

private[rewrite] class Session(rules: Seq[Rule]) {
private implicit val implicitSession: Session = this
private val claimed = new mutable.HashMap[Int, Claimant]()
private val claimed = new mutable.HashMap[Int, Replacement]()
private[FormatTokensRewrite] val tokens =
new mutable.ArrayBuffer[Replacement]()

@inline
def claimedRule(implicit ft: FormatToken): Option[Claimant] =
def claimedRule(implicit ft: FormatToken): Option[Replacement] =
claimedRule(ft.meta.idx)

@inline
private[rewrite] def claimedRule(ftIdx: Int): Option[Claimant] =
private[rewrite] def claimedRule(ftIdx: Int): Option[Replacement] =
claimed.get(ftIdx)

private[FormatTokensRewrite] def applyRule(
rule: Rule
)(implicit ft: FormatToken, style: ScalafmtConfig): Boolean =
rule.enabled && (rule.onToken match {
case Some(repl) =>
val claimant = new Claimant(rule, repl)
claimed.getOrElseUpdate(ft.meta.idx, claimant)
repl.claim.foreach { claimed.getOrElseUpdate(_, claimant) }
tokens.append(repl)
true
case _ => false
})
attemptedRule: Rule
)(implicit ft: FormatToken, style: ScalafmtConfig): Option[Rule] =
if (attemptedRule.enabled) attemptedRule.onToken.map { repl =>
claimed.getOrElseUpdate(ft.meta.idx, repl)
repl.claim.foreach { claimed.getOrElseUpdate(_, repl) }
tokens.append(repl)
repl.rule
}
else None

private[FormatTokensRewrite] def applyRules(
rules: Seq[Rule]
)(implicit ft: FormatToken, style: ScalafmtConfig): Option[Rule] = {
@tailrec
def iter(remainingRules: Seq[Rule]): Option[Rule] = remainingRules match {
case r +: rs => if (applyRule(r)) Some(r) else iter(rs)
case r +: rs =>
applyRule(r) match {
case None => iter(rs)
case x => x
}
case _ => None
}
iter(rules)
Expand All @@ -285,12 +315,8 @@ object FormatTokensRewrite {
rules.find(tag.runtimeClass.isInstance).map(_.asInstanceOf[A])
}

private[rewrite] class Claimant(
val rule: Rule,
val replacement: Replacement
)

private[rewrite] class Replacement(
val rule: Rule,
val ft: FormatToken,
val how: ReplacementType,
// list of FormatToken indices, with the claimed token on the **right**
Expand All @@ -303,33 +329,6 @@ object FormatTokensRewrite {
object Replace extends ReplacementType
}

private[rewrite] def removeToken(implicit ft: FormatToken): Replacement =
new Replacement(ft, ReplacementType.Remove)

private[rewrite] def replaceToken(
text: String,
owner: Option[Tree] = None,
claim: Iterable[Int] = Nil
)(tok: T)(implicit ft: FormatToken): Replacement = {
val mOld = ft.meta.right
val mNew = mOld.copy(text = text, owner = owner.getOrElse(mOld.owner))
val ftNew = ft.copy(right = tok, meta = ft.meta.copy(right = mNew))
new Replacement(ftNew, ReplacementType.Replace, claim)
}

private[rewrite] def replaceTokenBy(
text: String,
owner: Option[Tree] = None,
claim: Iterable[Int] = Nil
)(f: T => T)(implicit ft: FormatToken): Replacement =
replaceToken(text, owner, claim)(f(ft.right))

private[rewrite] def replaceTokenIdent(text: String, t: T)(implicit
ft: FormatToken
): Replacement = replaceToken(text)(
new T.Ident(t.input, t.dialect, t.start, t.start + text.length, text)
)

private def mergeWhitespaceLeftToRight(
lt: FormatToken.Meta,
rt: FormatToken.Meta
Expand Down
Expand Up @@ -56,17 +56,6 @@ object RedundantBraces extends Rewrite with FormatTokensRewrite.RuleFactory {
case x => nested || x.isDefined
})

// we might not keep it but will hint to onRight
private def replaceWithLeftParen(implicit ft: FormatToken): Replacement =
replaceTokenBy("(") { x =>
new Token.LeftParen(x.input, x.dialect, x.start)
}

private def replaceWithEquals(implicit ft: FormatToken): Replacement =
replaceTokenBy("=") { x =>
new Token.Equals(x.input, x.dialect, x.start)
}

}

/** Removes/adds curly braces where desired.
Expand Down Expand Up @@ -104,6 +93,17 @@ class RedundantBraces(implicit val ftoks: FormatTokens)
}
}

// we might not keep it but will hint to onRight
private def replaceWithLeftParen(implicit ft: FormatToken): Replacement =
replaceTokenBy("(") { x =>
new Token.LeftParen(x.input, x.dialect, x.start)
}

private def replaceWithEquals(implicit ft: FormatToken): Replacement =
replaceTokenBy("=") { x =>
new Token.Equals(x.input, x.dialect, x.start)
}

private def onLeftParen(implicit
ft: FormatToken,
style: ScalafmtConfig
Expand Down

0 comments on commit 8c00f0e

Please sign in to comment.