Skip to content

Commit

Permalink
Merge pull request #898 from asottile/constant-fold-types
Browse files Browse the repository at this point in the history
constant fold isinstance / issubclass / except
  • Loading branch information
asottile committed Sep 23, 2023
2 parents 0d46cba + f926532 commit 8943547
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 36 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,22 @@ A fix for [python-modernize/python-modernize#178]

[python-modernize/python-modernize#178]: https://github.com/python-modernize/python-modernize/issues/178

### constant fold `isinstance` / `issubclass` / `except`

```diff
-isinstance(x, (int, int))
+isinstance(x, int)

-issubclass(y, (str, str))
+issubclass(y, str)

try:
raises()
-except (Error1, Error1, Error2):
+except (Error1, Error2):
pass
```

### unittest deprecated aliases

Rewrites [deprecated unittest method aliases](https://docs.python.org/3/library/unittest.html#deprecated-aliases) to their non-deprecated forms.
Expand Down
10 changes: 10 additions & 0 deletions pyupgrade/_ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,13 @@ def is_async_listcomp(node: ast.ListComp) -> bool:
any(gen.is_async for gen in node.generators) or
contains_await(node)
)


def is_type_check(node: ast.AST) -> bool:
return (
isinstance(node, ast.Call) and
isinstance(node.func, ast.Name) and
node.func.id in {'isinstance', 'issubclass'} and
len(node.args) == 2 and
not has_starargs(node)
)
64 changes: 64 additions & 0 deletions pyupgrade/_plugins/constant_fold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import ast
from typing import Iterable

from tokenize_rt import Offset

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import is_type_check
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import constant_fold_tuple


def _to_name(node: ast.AST) -> str | None:
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
base = _to_name(node.value)
if base is None:
return None
else:
return f'{base}.{node.attr}'
else:
return None


def _can_constant_fold(node: ast.Tuple) -> bool:
seen = set()
for el in node.elts:
name = _to_name(el)
if name is not None:
if name in seen:
return True
else:
seen.add(name)
else:
return False


def _cbs(node: ast.AST | None) -> Iterable[tuple[Offset, TokenFunc]]:
if isinstance(node, ast.Tuple) and _can_constant_fold(node):
yield ast_to_offset(node), constant_fold_tuple


@register(ast.Call)
def visit_Call(
state: State,
node: ast.Call,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if is_type_check(node):
yield from _cbs(node.args[1])


@register(ast.Try)
def visit_Try(
state: State,
node: ast.Try,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
for handler in node.handlers:
yield from _cbs(handler.type)
31 changes: 5 additions & 26 deletions pyupgrade/_plugins/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._data import Version
from pyupgrade._token_helpers import arg_str
from pyupgrade._token_helpers import constant_fold_tuple
from pyupgrade._token_helpers import find_op
from pyupgrade._token_helpers import parse_call_args
from pyupgrade._token_helpers import replace_name
Expand Down Expand Up @@ -45,34 +45,13 @@ def _fix_except(
*,
at_idx: dict[int, _Target],
) -> None:
# find all the arg strs in the tuple
except_index = i
while tokens[except_index].src != 'except':
except_index -= 1
start = find_op(tokens, except_index, '(')
start = find_op(tokens, i, '(')
func_args, end = parse_call_args(tokens, start)

arg_strs = [arg_str(tokens, *arg) for arg in func_args]
for i, target in reversed(at_idx.items()):
tokens[slice(*func_args[i])] = [Token('NAME', target.target)]

# rewrite the block without dupes
args = []
for i, arg in enumerate(arg_strs):
target = at_idx.get(i)
if target is not None:
args.append(target.target)
else:
args.append(arg)

unique_args = tuple(dict.fromkeys(args))

if len(unique_args) > 1:
joined = '({})'.format(', '.join(unique_args))
elif tokens[start - 1].name != 'UNIMPORTANT_WS':
joined = f' {unique_args[0]}'
else:
joined = unique_args[0]

tokens[start:end] = [Token('CODE', joined)]
constant_fold_tuple(start, tokens)


def _get_rewrite(
Expand Down
13 changes: 3 additions & 10 deletions pyupgrade/_plugins/six_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tokenize_rt import Offset

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import is_type_check
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
Expand Down Expand Up @@ -36,14 +37,6 @@
}


def _is_type_check(node: ast.AST | None) -> bool:
return (
isinstance(node, ast.Call) and
isinstance(node.func, ast.Name) and
node.func.id in {'isinstance', 'issubclass'}
)


@register(ast.Attribute)
def visit_Attribute(
state: State,
Expand All @@ -62,7 +55,7 @@ def visit_Attribute(
):
return

if node.attr in NAMES_TYPE_CTX and _is_type_check(parent):
if node.attr in NAMES_TYPE_CTX and is_type_check(parent):
new = NAMES_TYPE_CTX[node.attr]
else:
new = NAMES[node.attr]
Expand Down Expand Up @@ -106,7 +99,7 @@ def visit_Name(
):
return

if node.id in NAMES_TYPE_CTX and _is_type_check(parent):
if node.id in NAMES_TYPE_CTX and is_type_check(parent):
new = NAMES_TYPE_CTX[node.id]
else:
new = NAMES[node.id]
Expand Down
17 changes: 17 additions & 0 deletions pyupgrade/_token_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,23 @@ def replace_argument(
tokens[start_idx:end_idx] = [Token('SRC', new)]


def constant_fold_tuple(i: int, tokens: list[Token]) -> None:
start = find_op(tokens, i, '(')
func_args, end = parse_call_args(tokens, start)
arg_strs = [arg_str(tokens, *arg) for arg in func_args]

unique_args = tuple(dict.fromkeys(arg_strs))

if len(unique_args) > 1:
joined = '({})'.format(', '.join(unique_args))
elif tokens[start - 1].name != 'UNIMPORTANT_WS':
joined = f' {unique_args[0]}'
else:
joined = unique_args[0]

tokens[start:end] = [Token('CODE', joined)]


def has_space_before(i: int, tokens: list[Token]) -> bool:
return i >= 1 and tokens[i - 1].name in {UNIMPORTANT_WS, 'INDENT'}

Expand Down
91 changes: 91 additions & 0 deletions tests/features/constant_fold_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

import pytest

from pyupgrade._data import Settings
from pyupgrade._main import _fix_plugins


@pytest.mark.parametrize(
's',
(
pytest.param(
'isinstance(x, str)',
id='isinstance nothing duplicated',
),
pytest.param(
'issubclass(x, str)',
id='issubclass nothing duplicated',
),
pytest.param(
'try: ...\n'
'except Exception: ...\n',
id='try-except nothing duplicated',
),
pytest.param(
'isinstance(x, (str, (str,)))',
id='only consider flat tuples',
),
pytest.param(
'isinstance(x, (f(), a().g))',
id='only consider names and dotted names',
),
),
)
def test_constant_fold_noop(s):
assert _fix_plugins(s, settings=Settings()) == s


@pytest.mark.parametrize(
('s', 'expected'),
(
pytest.param(
'isinstance(x, (str, str, int))',
'isinstance(x, (str, int))',
id='isinstance',
),
pytest.param(
'issubclass(x, (str, str, int))',
'issubclass(x, (str, int))',
id='issubclass',
),
pytest.param(
'try: ...\n'
'except (Exception, Exception, TypeError): ...\n',
'try: ...\n'
'except (Exception, TypeError): ...\n',
id='except',
),
pytest.param(
'isinstance(x, (str, str))',
'isinstance(x, str)',
id='folds to 1',
),
pytest.param(
'isinstance(x, (a.b, a.b, a.c))',
'isinstance(x, (a.b, a.c))',
id='folds dotted names',
),
pytest.param(
'try: ...\n'
'except(a, a): ...\n',
'try: ...\n'
'except a: ...\n',
id='deduplication to 1 does not cause syntax error with except',
),
),
)
def test_constant_fold(s, expected):
assert _fix_plugins(s, settings=Settings()) == expected

0 comments on commit 8943547

Please sign in to comment.