diff --git a/scalameta/parsers/shared/src/main/scala/scala/meta/internal/parsers/ScalametaParser.scala b/scalameta/parsers/shared/src/main/scala/scala/meta/internal/parsers/ScalametaParser.scala index 3039b4432f..3776c3e3a4 100644 --- a/scalameta/parsers/shared/src/main/scala/scala/meta/internal/parsers/ScalametaParser.scala +++ b/scalameta/parsers/shared/src/main/scala/scala/meta/internal/parsers/ScalametaParser.scala @@ -544,7 +544,6 @@ class ScalametaParser(input: Input)(implicit dialect: Dialect) { parser => case Ident(x) => pred(x) case _ => false } - def isUnaryOp: Boolean = isIdentAnd(token, _.isUnaryOp) def isIdentExcept(except: String) = isIdentAnd(token, _ != except) def isIdentOf(tok: Token, name: String) = isIdentAnd(tok, _ == name) @inline def isStar: Boolean = isStar(token) @@ -1113,8 +1112,8 @@ class ScalametaParser(input: Input)(implicit dialect: Dialect) { parser => case _: Literal => if (dialect.allowLiteralTypes) literal() else syntaxError(s"$dialect doesn't support literal types", at = path()) - case Ident("-") if dialect.allowLiteralTypes && tryAhead[NumericConstant[_]] => - numericLiteral(prevTokenPos, isNegated = true) + case Unary.Numeric(unary) if dialect.allowLiteralTypes && tryAhead[NumericConstant[_]] => + numericLiteral(prevTokenPos, unary) case _ => pathSimpleType() } simpleTypeRest(autoEndPosOpt(startPos)(res), startPos) @@ -1332,29 +1331,30 @@ class ScalametaParser(input: Input)(implicit dialect: Dialect) { parser => if (acceptOpt[Dot]) selectors(name) else name } - private def numericLiteral(startPos: Int, isNegated: Boolean): Lit = { + private def numericLiteral(startPos: Int, unary: Unary.Numeric): Lit = { val number = token.asInstanceOf[NumericConstant[_]] next() - autoEndPos(startPos)(numericLiteralAt(number, isNegated)) + autoEndPos(startPos)(numericLiteralAt(number, unary)) } - private def numericLiteralAt(token: NumericConstant[_], isNegated: Boolean): Lit = { + private def numericLiteralAt(token: NumericConstant[_], unary: Unary.Numeric): Lit = { def getBigInt(tok: NumericConstant[BigInt], dec: BigInt, hex: BigInt, typ: String) = { // decimal never starts with `0` as octal was removed in 2.11; "hex" includes `0x` or `0b` // non-decimal literals allow signed overflow within unsigned range val max = if (tok.text(0) != '0') dec else hex // token value is always positive as it doesn't take into account a sign val value = tok.value - if (isNegated) { - if (value > max) syntaxError(s"integer number too small for $typ", at = token) - -value + val result = unary(value) + if (result.signum < 0) { + if (value > max) syntaxError(s"integer number too small for $typ", at = tok) } else { - if (value >= max) syntaxError(s"integer number too large for $typ", at = token) - value + if (value >= max) syntaxError(s"integer number too large for $typ", at = tok) } + result } def getBigDecimal(tok: NumericConstant[BigDecimal]) = - if (isNegated) -tok.value else tok.value + unary(tok.value) + .getOrElse(syntaxError(s"bad unary op `${unary.op}` for floating-point", at = tok)) token match { case tok: Constant.Int => Lit.Int(getBigInt(tok, bigIntMaxInt, bigIntMaxUInt, "Int").intValue) @@ -1369,7 +1369,7 @@ class ScalametaParser(input: Input)(implicit dialect: Dialect) { parser => } def literal(): Lit = atCurPosNext(token match { - case number: NumericConstant[_] => numericLiteralAt(number, false) + case number: NumericConstant[_] => numericLiteralAt(number, Unary.Noop) case Constant.Char(value) => Lit.Char(value) case Constant.String(value) => Lit.String(value) case t: Constant.Symbol => @@ -2228,16 +2228,14 @@ class ScalametaParser(input: Input)(implicit dialect: Dialect) { parser => } } - def prefixExpr(allowRepeated: Boolean): Term = - if (!isUnaryOp) simpleExpr(allowRepeated) - else { + def prefixExpr(allowRepeated: Boolean): Term = token match { + case Unary((ident, unary)) => val startPos = tokenPos - val op = termName() + next() + def op = atPos(startPos)(Term.Name(ident)) def addPos(tree: Term) = autoEndPos(startPos)(tree) def rest(tree: Term) = simpleExprRest(tree, canApply = true, startPos = startPos) - if (op.value == "-" && token.is[NumericConstant[_]]) - rest(numericLiteral(startPos, isNegated = true)) - else { + def otherwise = simpleExpr0(allowRepeated = true) match { case Success(result) => addPos(Term.ApplyUnary(op, result)) case Failure(_) => @@ -2245,8 +2243,15 @@ class ScalametaParser(input: Input)(implicit dialect: Dialect) { parser => // we would fail here anyway, let's try to treat it as ident rest(op) } + (token, unary) match { + case (tok: NumericConstant[_], unary: Unary.Numeric) => + next(); rest(addPos(numericLiteralAt(tok, unary))) + case (tok: BooleanConstant, unary: Unary.Logical) => + next(); rest(addPos(Lit.Boolean(unary(tok.value)))) + case _ => otherwise } - } + case _ => simpleExpr(allowRepeated) + } def simpleExpr(allowRepeated: Boolean): Term = simpleExpr0(allowRepeated).get @@ -2910,11 +2915,10 @@ class ScalametaParser(input: Input)(implicit dialect: Dialect) { parser => autoEndPos(startPos)(token match { case sidToken @ (_: Ident | _: KwThis | _: Unquote) => val sid = stableId() - if (token.is[NumericConstant[_]]) { - sid match { - case Term.Name("-") => return numericLiteral(startPos, isNegated = true) - case _ => - } + (token, sidToken) match { + case (_: NumericConstant[_], Unary.Numeric(unary)) if prevTokenPos == startPos => + return numericLiteral(startPos, unary) + case _ => } val targs = if (token.is[LeftBracket]) Some(super.patternTypeArgs()) else None if (token.is[LeftParen]) { diff --git a/scalameta/trees/shared/src/main/scala/scala/meta/internal/trees/Unary.scala b/scalameta/trees/shared/src/main/scala/scala/meta/internal/trees/Unary.scala new file mode 100644 index 0000000000..7f4a68b277 --- /dev/null +++ b/scalameta/trees/shared/src/main/scala/scala/meta/internal/trees/Unary.scala @@ -0,0 +1,63 @@ +package scala.meta.internal.trees + +import scala.meta.tokens.Token + +private[meta] sealed trait Unary { + def op: String +} + +private[meta] object Unary { + + private val numericOpMap = Seq[Numeric](Plus, Minus, Tilde).map(x => x.op -> x).toMap + val opMap = numericOpMap ++ Seq[Unary](Not).map(x => x.op -> x) + + def unapply(token: Token.Ident): Option[(String, Unary)] = { + val op = token.text + opMap.get(op).map(op -> _) + } + + sealed trait Numeric extends Unary { + def apply(value: BigInt): BigInt + // could return None if not applicable (such as `~`) + def apply(value: BigDecimal): Option[BigDecimal] + } + + object Numeric { + def unapply(token: Token.Ident): Option[Numeric] = + numericOpMap.get(token.text) + } + + sealed trait Logical extends Unary { + def apply(value: Boolean): Boolean + } + + case object Noop extends Numeric { + val op = "" + def apply(value: BigInt): BigInt = value + def apply(value: BigDecimal): Option[BigDecimal] = Some(value) + } + + case object Plus extends Numeric { + val op = "+" + def apply(value: BigInt): BigInt = value + def apply(value: BigDecimal): Option[BigDecimal] = Some(value) + } + + case object Minus extends Numeric { + val op = "-" + def apply(value: BigInt): BigInt = -value + def apply(value: BigDecimal): Option[BigDecimal] = Some(-value) + } + + case object Tilde extends Numeric { + val op = "~" + def apply(value: BigInt): BigInt = ~value + def apply(value: BigDecimal): Option[BigDecimal] = None + } + + case object Not extends Logical { + val op = "!" + def apply(value: Boolean): Boolean = !value + } + +} diff --git a/scalameta/trees/shared/src/main/scala/scala/meta/internal/trees/package.scala b/scalameta/trees/shared/src/main/scala/scala/meta/internal/trees/package.scala index abbaeb58b5..bb06ee8c61 100644 --- a/scalameta/trees/shared/src/main/scala/scala/meta/internal/trees/package.scala +++ b/scalameta/trees/shared/src/main/scala/scala/meta/internal/trees/package.scala @@ -43,7 +43,7 @@ package object trees { // some heuristic is needed to govern associativity and precedence of unquoted operators def isLeftAssoc: Boolean = value.last != ':' - def isUnaryOp: Boolean = Set("-", "+", "~", "!").contains(value) + def isUnaryOp: Boolean = Unary.opMap.contains(value) def isAssignmentOp = value match { case "!=" | "<=" | ">=" | "" => false diff --git a/tests/shared/src/test/scala/scala/meta/tests/parsers/LitSuite.scala b/tests/shared/src/test/scala/scala/meta/tests/parsers/LitSuite.scala index 8ba1a6cb59..e19adbc7c3 100644 --- a/tests/shared/src/test/scala/scala/meta/tests/parsers/LitSuite.scala +++ b/tests/shared/src/test/scala/scala/meta/tests/parsers/LitSuite.scala @@ -271,9 +271,9 @@ class LitSuite extends ParseSuite { } test("unary: +1") { - runTestAssert[Stat]("+1")(Term.ApplyUnary(tname("+"), lit(1))) - val tree = Term.ApplyUnary(tname("+"), Term.Apply(lit(1), List(lit(0)))) - runTestAssert[Stat]("+1(0)")(tree) + runTestAssert[Stat]("+1", "1")(lit(1)) + val tree = Term.Apply(lit(1), List(lit(0))) + runTestAssert[Stat]("+1(0)", "1(0)")(tree) } test("unary: -1") { @@ -283,9 +283,9 @@ class LitSuite extends ParseSuite { } test("unary: ~1") { - runTestAssert[Stat]("~1")(Term.ApplyUnary(tname("~"), lit(1))) - val tree = Term.ApplyUnary(tname("~"), Term.Apply(lit(1), List(lit(0)))) - runTestAssert[Stat]("~1(0)")(tree) + runTestAssert[Stat]("~1", "-2")(lit(-2)) + val tree = Term.Apply(lit(-2), List(lit(0))) + runTestAssert[Stat]("~1(0)", "-2(0)")(tree) } test("unary: !1") { @@ -295,9 +295,9 @@ class LitSuite extends ParseSuite { } test("unary: +1.0") { - runTestAssert[Stat]("+1.0", "+1.0d")(Term.ApplyUnary(tname("+"), lit(1d))) - val tree = Term.ApplyUnary(tname("+"), Term.Apply(lit(1d), List(lit(0)))) - runTestAssert[Stat]("+1.0(0)", "+1.0d(0)")(tree) + runTestAssert[Stat]("+1.0", "1.0d")(lit(1d)) + val tree = Term.Apply(lit(1d), List(lit(0))) + runTestAssert[Stat]("+1.0(0)", "1.0d(0)")(tree) } test("unary: -1.0") { @@ -307,9 +307,12 @@ class LitSuite extends ParseSuite { } test("unary: ~1.0") { - runTestAssert[Stat]("~1.0", "~1.0d")(Term.ApplyUnary(tname("~"), lit(1d))) - val tree = Term.ApplyUnary(tname("~"), Term.Apply(lit(1d), List(lit(0)))) - runTestAssert[Stat]("~1.0(0)", "~1.0d(0)")(tree) + def error(code: String) = + s"""|:1: error: bad unary op `~` for floating-point + |~$code + | ^""".stripMargin + runTestError[Stat]("~1.0", error("1.0")) + runTestError[Stat]("~1.0(0)", error("1.0(0)")) } test("unary: !1.0") { @@ -319,15 +322,15 @@ class LitSuite extends ParseSuite { } test("unary: !true") { - runTestAssert[Stat]("!true")(Term.ApplyUnary(tname("!"), lit(true))) - val tree = Term.ApplyUnary(tname("!"), Term.Apply(lit(true), List(lit(0)))) - runTestAssert[Stat]("!true(0)")(tree) + runTestAssert[Stat]("!true", "false")(lit(false)) + val tree = Term.Apply(lit(false), List(lit(0))) + runTestAssert[Stat]("!true(0)", "false(0)")(tree) } test("unary: !false") { - runTestAssert[Stat]("!false")(Term.ApplyUnary(tname("!"), lit(false))) - val tree = Term.ApplyUnary(tname("!"), Term.Apply(lit(false), List(lit(0)))) - runTestAssert[Stat]("!false(0)")(tree) + runTestAssert[Stat]("!false", "true")(lit(true)) + val tree = Term.Apply(lit(true), List(lit(0))) + runTestAssert[Stat]("!false(0)", "true(0)")(tree) } test("scalatest-like infix without literal") { diff --git a/tests/shared/src/test/scala/scala/meta/tests/prettyprinters/TreeSyntaxSuite.scala b/tests/shared/src/test/scala/scala/meta/tests/prettyprinters/TreeSyntaxSuite.scala index 93d6023581..320ef8302b 100644 --- a/tests/shared/src/test/scala/scala/meta/tests/prettyprinters/TreeSyntaxSuite.scala +++ b/tests/shared/src/test/scala/scala/meta/tests/prettyprinters/TreeSyntaxSuite.scala @@ -76,13 +76,13 @@ class TreeSyntaxSuite extends scala.meta.tests.parsers.ParseSuite { testBlockAddNL("Foo.bar") testBlockAddNL("10") testBlockAddNL("-10") - testBlockAddNL("~10") + testBlockAddNL("~10", "-11") testBlockAddNL("10.0d") testBlockAddNL("-10.0d") testBlockAddNL("true") testBlockAddNL("false") - testBlockAddNL("!true") - testBlockAddNL("!false") + testBlockAddNL("!true", "false") + testBlockAddNL("!false", "true") testBlockNoNL("-{10}", "-{\n 10\n}") testBlockAddNL("foo(bar)") testBlockAddNL("foo[Bar]") diff --git a/tests/shared/src/test/scala/scala/meta/tests/trees/TreeSuite.scala b/tests/shared/src/test/scala/scala/meta/tests/trees/TreeSuite.scala index 30a4991d63..f9450e5e44 100644 --- a/tests/shared/src/test/scala/scala/meta/tests/trees/TreeSuite.scala +++ b/tests/shared/src/test/scala/scala/meta/tests/trees/TreeSuite.scala @@ -3,10 +3,29 @@ package trees import munit._ import scala.meta._ +import scala.meta.internal.trees._ class TreeSuite extends FunSuite { test("Name.unapply") { assert(Name.unapply(q"a").contains("a")) assert(Name.unapply(t"a").contains("a")) } + + Seq( + ("+", Unary.Plus), + ("-", Unary.Minus), + ("~", Unary.Tilde), + ("!", Unary.Not) + ).foreach { case (op, unary) => + test(s"Unary.$unary") { + assertEquals(unary.op, op) + assertEquals(Unary.opMap.get(op), Some(unary)) + assert(op.isUnaryOp) + } + } + + test(s"Unary opMap size") { + assertEquals(Unary.opMap.size, 4) + } + }