Skip to content

Commit

Permalink
Merge pull request #901 from UnknownPlatypus/defaultdict-lambdas
Browse files Browse the repository at this point in the history
Defaultdict lambdas
  • Loading branch information
asottile committed Oct 7, 2023
2 parents 72041b1 + f001635 commit d1518e8
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 0 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,32 @@ Sample `.pre-commit-config.yaml`:
+{a: b for a, b in y}
```

### Replace unnecessary lambdas in `collections.defaultdict` calls

```diff
-defaultdict(lambda: [])
+defaultdict(list)
-defaultdict(lambda: list())
+defaultdict(list)
-defaultdict(lambda: {})
+defaultdict(dict)
-defaultdict(lambda: dict())
+defaultdict(dict)
-defaultdict(lambda: ())
+defaultdict(tuple)
-defaultdict(lambda: tuple())
+defaultdict(tuple)
-defaultdict(lambda: set())
+defaultdict(set)
-defaultdict(lambda: 0)
+defaultdict(int)
-defaultdict(lambda: 0.0)
+defaultdict(float)
-defaultdict(lambda: 0j)
+defaultdict(complex)
-defaultdict(lambda: '')
+defaultdict(str)
```

### Format Specifiers

Expand Down
1 change: 1 addition & 0 deletions pyupgrade/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class State(NamedTuple):
RECORD_FROM_IMPORTS = frozenset((
'__future__',
'asyncio',
'collections',
'functools',
'mmap',
'os',
Expand Down
81 changes: 81 additions & 0 deletions pyupgrade/_plugins/defauldict_lambda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

import ast
import functools
from typing import Iterable

from tokenize_rt import Offset
from tokenize_rt import Token

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import is_name_attr
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import find_op
from pyupgrade._token_helpers import parse_call_args


def _eligible_lambda_replacement(lambda_expr: ast.Lambda) -> str | None:
if isinstance(lambda_expr.body, ast.Constant):
if lambda_expr.body.value == 0:
return type(lambda_expr.body.value).__name__
elif lambda_expr.body.value == '':
return 'str'
else:
return None
elif isinstance(lambda_expr.body, ast.List) and not lambda_expr.body.elts:
return 'list'
elif isinstance(lambda_expr.body, ast.Tuple) and not lambda_expr.body.elts:
return 'tuple'
elif isinstance(lambda_expr.body, ast.Dict) and not lambda_expr.body.keys:
return 'dict'
elif (
isinstance(lambda_expr.body, ast.Call) and
isinstance(lambda_expr.body.func, ast.Name) and
not lambda_expr.body.args and
not lambda_expr.body.keywords and
lambda_expr.body.func.id in {'dict', 'list', 'set', 'tuple'}
):
return lambda_expr.body.func.id
else:
return None


def _fix_defaultdict_first_arg(
i: int,
tokens: list[Token],
*,
replacement: str,
) -> None:
start = find_op(tokens, i, '(')
func_args, end = parse_call_args(tokens, start)

tokens[slice(*func_args[0])] = [Token('CODE', replacement)]


@register(ast.Call)
def visit_Call(
state: State,
node: ast.Call,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if (
is_name_attr(
node.func,
state.from_imports,
('collections',),
('defaultdict',),
) and
node.args and
isinstance(node.args[0], ast.Lambda)
):
replacement = _eligible_lambda_replacement(node.args[0])
if replacement is None:
return

func = functools.partial(
_fix_defaultdict_first_arg,
replacement=replacement,
)
yield ast_to_offset(node), func
220 changes: 220 additions & 0 deletions tests/features/defauldict_lambda_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
from __future__ import annotations

import pytest

from pyupgrade._data import Settings
from pyupgrade._main import _fix_plugins


@pytest.mark.parametrize(
('s',),
(
pytest.param(
'from collections import defaultdict as dd\n\n'
'dd(lambda: set())\n',
id='not following as imports',
),
pytest.param(
'from collections2 import defaultdict\n\n'
'dd(lambda: dict())\n',
id='not following unknown import',
),
pytest.param(
'from .collections import defaultdict\n'
'defaultdict(lambda: list())\n',
id='relative imports',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: {1}))\n',
id='non empty set',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: [1]))\n'
'defaultdict(lambda: list([1])))\n',
id='non empty list',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: {1: 2})\n',
id='non empty dict, literal',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: dict([(1,2),])))\n',
id='non empty dict, call with args',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: dict(a=[1]))\n',
id='non empty dict, call with kwargs',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: (1,))\n',
id='non empty tuple, literal',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: tuple([1]))\n',
id='non empty tuple, calls with arg',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: "AAA")\n'
'defaultdict(lambda: \'BBB\')\n',
id='non empty string',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: 10)\n'
'defaultdict(lambda: -2)\n',
id='non zero integer',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: 0.2)\n'
'defaultdict(lambda: 0.00000001)\n'
'defaultdict(lambda: -2.3)\n',
id='non zero float',
),
pytest.param(
'import collections\n'
'collections.defaultdict(lambda: None)\n',
id='lambda: None is not equivalent to defaultdict()',
),
),
)
def test_fix_noop(s):
assert _fix_plugins(s, settings=Settings()) == s


@pytest.mark.parametrize(
('s', 'expected'),
(
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: set())\n',
'from collections import defaultdict\n\n'
'defaultdict(set)\n',
id='call with attr, set()',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: list())\n',
'from collections import defaultdict\n\n'
'defaultdict(list)\n',
id='call with attr, list()',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: dict())\n',
'from collections import defaultdict\n\n'
'defaultdict(dict)\n',
id='call with attr, dict()',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: tuple())\n',
'from collections import defaultdict\n\n'
'defaultdict(tuple)\n',
id='call with attr, tuple()',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: [])\n',
'from collections import defaultdict\n\n'
'defaultdict(list)\n',
id='call with attr, []',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: {})\n',
'from collections import defaultdict\n\n'
'defaultdict(dict)\n',
id='call with attr, {}',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: ())\n',
'from collections import defaultdict\n\n'
'defaultdict(tuple)\n',
id='call with attr, ()',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: "")\n',
'from collections import defaultdict\n\n'
'defaultdict(str)\n',
id='call with attr, empty string (double quote)',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: \'\')\n',
'from collections import defaultdict\n\n'
'defaultdict(str)\n',
id='call with attr, empty string (single quote)',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: 0)\n',
'from collections import defaultdict\n\n'
'defaultdict(int)\n',
id='call with attr, int',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: 0.0)\n',
'from collections import defaultdict\n\n'
'defaultdict(float)\n',
id='call with attr, float',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: 0.0000)\n',
'from collections import defaultdict\n\n'
'defaultdict(float)\n',
id='call with attr, long float',
),
pytest.param(
'from collections import defaultdict\n\n'
'defaultdict(lambda: [], {1: []})\n',
'from collections import defaultdict\n\n'
'defaultdict(list, {1: []})\n',
id='defauldict with kwargs',
),
pytest.param(
'import collections\n\n'
'collections.defaultdict(lambda: set())\n'
'collections.defaultdict(lambda: list())\n'
'collections.defaultdict(lambda: dict())\n'
'collections.defaultdict(lambda: tuple())\n'
'collections.defaultdict(lambda: [])\n'
'collections.defaultdict(lambda: {})\n'
'collections.defaultdict(lambda: "")\n'
'collections.defaultdict(lambda: \'\')\n'
'collections.defaultdict(lambda: 0)\n'
'collections.defaultdict(lambda: 0.0)\n'
'collections.defaultdict(lambda: 0.00000)\n'
'collections.defaultdict(lambda: 0j)\n',
'import collections\n\n'
'collections.defaultdict(set)\n'
'collections.defaultdict(list)\n'
'collections.defaultdict(dict)\n'
'collections.defaultdict(tuple)\n'
'collections.defaultdict(list)\n'
'collections.defaultdict(dict)\n'
'collections.defaultdict(str)\n'
'collections.defaultdict(str)\n'
'collections.defaultdict(int)\n'
'collections.defaultdict(float)\n'
'collections.defaultdict(float)\n'
'collections.defaultdict(complex)\n',
id='call with attr',
),
),
)
def test_fix_defaultdict(s, expected):
ret = _fix_plugins(s, settings=Settings())
assert ret == expected

0 comments on commit d1518e8

Please sign in to comment.