Skip to content

Commit

Permalink
Merge pull request #855 from asottile/shlex-join
Browse files Browse the repository at this point in the history
automatically rewrite to shlex.join in --py38-plus
  • Loading branch information
asottile committed Jul 3, 2023
2 parents 908b055 + 193fc2f commit e7fc4db
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 0 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,15 @@ Availability:
...
```

### shlex.join

Availability:
- `--py38-plus` is passed on the commandline.

```diff
-' '.join(shlex.quote(arg) for arg in cmd)
+shlex.join(cmd)
```

### replace `@functools.lru_cache(maxsize=None)` with shorthand

Expand Down
62 changes: 62 additions & 0 deletions pyupgrade/_plugins/shlex_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

import ast
import functools
from typing import Iterable

from tokenize_rt import NON_CODING_TOKENS
from tokenize_rt import Offset
from tokenize_rt import Token

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import find_open_paren
from pyupgrade._token_helpers import find_token
from pyupgrade._token_helpers import victims


def _fix_shlex_join(i: int, tokens: list[Token], *, arg: ast.expr) -> None:
j = find_open_paren(tokens, i)
comp_victims = victims(tokens, j, arg, gen=True)
k = find_token(tokens, comp_victims.arg_index, 'in') + 1
while tokens[k].name in NON_CODING_TOKENS:
k += 1
tokens[comp_victims.ends[0]:comp_victims.ends[-1] + 1] = [Token('OP', ')')]
tokens[i:k] = [Token('CODE', 'shlex.join'), Token('OP', '(')]


@register(ast.Call)
def visit_Call(
state: State,
node: ast.Call,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if state.settings.min_version < (3, 8):
return

if (
isinstance(node.func, ast.Attribute) and
isinstance(node.func.value, ast.Constant) and
isinstance(node.func.value.value, str) and
node.func.attr == 'join' and
not node.keywords and
len(node.args) == 1 and
isinstance(node.args[0], (ast.ListComp, ast.GeneratorExp)) and
isinstance(node.args[0].elt, ast.Call) and
isinstance(node.args[0].elt.func, ast.Attribute) and
isinstance(node.args[0].elt.func.value, ast.Name) and
node.args[0].elt.func.value.id == 'shlex' and
node.args[0].elt.func.attr == 'quote' and
not node.args[0].elt.keywords and
len(node.args[0].elt.args) == 1 and
isinstance(node.args[0].elt.args[0], ast.Name) and
len(node.args[0].generators) == 1 and
isinstance(node.args[0].generators[0].target, ast.Name) and
not node.args[0].generators[0].ifs and
not node.args[0].generators[0].is_async and
node.args[0].elt.args[0].id == node.args[0].generators[0].target.id
):
func = functools.partial(_fix_shlex_join, arg=node.args[0])
yield ast_to_offset(node), func
72 changes: 72 additions & 0 deletions tests/features/shlex_join_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

import pytest

from pyupgrade._data import Settings
from pyupgrade._main import _fix_plugins


@pytest.mark.parametrize(
('s', 'version'),
(
pytest.param(
'from shlex import quote\n'
'" ".join(quote(arg) for arg in cmd)\n',
(3, 8),
id='quote from-imported',
),
pytest.param(
'import shlex\n'
'" ".join(shlex.quote(arg) for arg in cmd)\n',
(3, 7),
id='3.8+ feature',
),
),
)
def test_shlex_join_noop(s, version):
assert _fix_plugins(s, settings=Settings(min_version=version)) == s


@pytest.mark.parametrize(
('s', 'expected'),
(
pytest.param(
'import shlex\n'
'" ".join(shlex.quote(arg) for arg in cmd)\n',
'import shlex\n'
'shlex.join(cmd)\n',
id='generator expression',
),
pytest.param(
'import shlex\n'
'" ".join([shlex.quote(arg) for arg in cmd])\n',
'import shlex\n'
'shlex.join(cmd)\n',
id='list comprehension',
),
pytest.param(
'import shlex\n'
'" ".join([shlex.quote(arg) for arg in cmd],)\n',
'import shlex\n'
'shlex.join(cmd)\n',
id='removes trailing comma',
),
pytest.param(
'import shlex\n'
'" ".join([shlex.quote(arg) for arg in ["a", "b", "c"]],)\n',
'import shlex\n'
'shlex.join(["a", "b", "c"])\n',
id='more complicated iterable',
),
),
)
def test_shlex_join_fixes(s, expected):
assert _fix_plugins(s, settings=Settings(min_version=(3, 8))) == expected

0 comments on commit e7fc4db

Please sign in to comment.