Skip to content

Commit

Permalink
Merge pull request #885 from mxr/timeout-errors
Browse files Browse the repository at this point in the history
Unify exception replacing and rewrite some TimeoutError cases
  • Loading branch information
asottile committed Sep 23, 2023
2 parents 3bbf781 + 719c224 commit 19b1966
Show file tree
Hide file tree
Showing 5 changed files with 431 additions and 652 deletions.
1 change: 1 addition & 0 deletions pyupgrade/_data.py
Expand Up @@ -38,6 +38,7 @@ class State(NamedTuple):

RECORD_FROM_IMPORTS = frozenset((
'__future__',
'asyncio',
'functools',
'mmap',
'os',
Expand Down
165 changes: 165 additions & 0 deletions pyupgrade/_plugins/exceptions.py
@@ -0,0 +1,165 @@
from __future__ import annotations

import ast
import functools
from typing import Iterable
from typing import NamedTuple

from tokenize_rt import Offset
from tokenize_rt import Token

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._data import Version
from pyupgrade._token_helpers import arg_str
from pyupgrade._token_helpers import find_op
from pyupgrade._token_helpers import parse_call_args
from pyupgrade._token_helpers import replace_name


class _Target(NamedTuple):
target: str
module: str | None
name: str
min_version: Version


_TARGETS = (
_Target('OSError', 'mmap', 'error', (3,)),
_Target('OSError', 'os', 'error', (3,)),
_Target('OSError', 'select', 'error', (3,)),
_Target('OSError', 'socket', 'error', (3,)),
_Target('OSError', None, 'IOError', (3,)),
_Target('OSError', None, 'EnvironmentError', (3,)),
_Target('OSError', None, 'WindowsError', (3,)),
_Target('TimeoutError', 'socket', 'timeout', (3, 10)),
_Target('TimeoutError', 'asyncio', 'TimeoutError', (3, 11)),
)


def _fix_except(
i: int,
tokens: list[Token],
*,
at_idx: dict[int, _Target],
) -> None:
# find all the arg strs in the tuple
except_index = i
while tokens[except_index].src != 'except':
except_index -= 1
start = find_op(tokens, except_index, '(')
func_args, end = parse_call_args(tokens, start)

# save the exceptions and remove the block
arg_strs = [arg_str(tokens, *arg) for arg in func_args]
del tokens[start:end]

# rewrite the block without dupes
args = []
for i, arg in enumerate(arg_strs):
target = at_idx.get(i)
if target is not None:
args.append(target.target)
else:
args.append(arg)

unique_args = tuple(dict.fromkeys(args))

if len(unique_args) > 1:
joined = '({})'.format(', '.join(unique_args))
elif tokens[start - 1].name != 'UNIMPORTANT_WS':
joined = f' {unique_args[0]}'
else:
joined = unique_args[0]

new = Token('CODE', joined)
tokens.insert(start, new)


def _get_rewrite(
node: ast.AST,
state: State,
targets: list[_Target],
) -> _Target | None:
for target in targets:
if (
target.module is None and
isinstance(node, ast.Name) and
node.id == target.name
):
return target
elif (
target.module is not None and
isinstance(node, ast.Name) and
node.id == target.name and
node.id in state.from_imports[target.module]
):
return target
elif (
target.module is not None and
isinstance(node, ast.Attribute) and
isinstance(node.value, ast.Name) and
node.attr == target.name and
node.value.id == target.module
):
return target
else:
return None


def _alias_cbs(
node: ast.expr,
state: State,
targets: list[_Target],
) -> Iterable[tuple[Offset, TokenFunc]]:
target = _get_rewrite(node, state, targets)
if target is not None:
func = functools.partial(
replace_name,
name=target.name,
new=target.target,
)
yield ast_to_offset(node), func


@register(ast.Raise)
def visit_Raise(
state: State,
node: ast.Raise,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
targets = [
target for target in _TARGETS
if state.settings.min_version >= target.min_version
]
if node.exc is not None:
yield from _alias_cbs(node.exc, state, targets)
if isinstance(node.exc, ast.Call):
yield from _alias_cbs(node.exc.func, state, targets)


@register(ast.Try)
def visit_Try(
state: State,
node: ast.Try,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
targets = [
target for target in _TARGETS
if state.settings.min_version >= target.min_version
]
for handler in node.handlers:
if isinstance(handler.type, ast.Tuple):
at_idx = {}
for i, elt in enumerate(handler.type.elts):
target = _get_rewrite(elt, state, targets)
if target is not None:
at_idx[i] = target

if at_idx:
func = functools.partial(_fix_except, at_idx=at_idx)
yield ast_to_offset(handler.type), func
elif handler.type is not None:
yield from _alias_cbs(handler.type, state, targets)
136 changes: 0 additions & 136 deletions pyupgrade/_plugins/oserror_aliases.py

This file was deleted.

0 comments on commit 19b1966

Please sign in to comment.