Skip to content

Commit

Permalink
fix asottile#819, TimoutError aliases
Browse files Browse the repository at this point in the history
* socket.timeout 3.10+
* asyncio.TimeoutError 3.11+
* concurrent.futures.TimeoutError 3.11+
  • Loading branch information
lancelote committed May 13, 2023
1 parent 611d9e6 commit 8a8e164
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 0 deletions.
53 changes: 53 additions & 0 deletions pyupgrade/_plugins/timeout_error_aliases.py
@@ -0,0 +1,53 @@
from __future__ import annotations

import ast
import functools
from typing import Iterable

from tokenize_rt import Offset

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import replace_name

ALIASES = {
('socket', 'timeout'): (3, 10),
('asyncio', 'TimeoutError'): (3, 11),
('futures', 'TimeoutError'): (3, 11),
}


@register(ast.Attribute)
def visit_Attribute(
state: State,
node: ast.Attribute,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if (
isinstance(node.value, ast.Name) and
(node.value.id, node.attr) in ALIASES and
state.settings.min_version >= ALIASES[(node.value.id, node.attr)]
):
func = functools.partial(
replace_name,
name=node.attr,
new='TimeoutError',
)
yield ast_to_offset(node), func


@register(ast.Name)
def visit_Name(
state: State,
node: ast.Name,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if node.id in state.from_imports['socket'] and node.id == 'timeout':
func = functools.partial(
replace_name,
name='timeout',
new='TimeoutError',
)
yield ast_to_offset(node), func
132 changes: 132 additions & 0 deletions tests/features/timeout_error_aliases_test.py
@@ -0,0 +1,132 @@
from __future__ import annotations

import pytest

from pyupgrade._data import Settings
from pyupgrade._main import _fix_plugins


@pytest.mark.parametrize(
('s', 'expected', 'version'),
(
pytest.param(
'import socket\n'
'raise socket.timeout("error")',
'import socket\n'
'raise TimeoutError("error")',
(3, 10),
id='rewriting socket.timeout',
),
pytest.param(
'from socket import timeout\n'
'raise timeout("error")',
'from socket import timeout\n'
'raise TimeoutError("error")',
(3, 10),
id='rewriting timeout',
),
pytest.param(
'import asyncio\n'
'raise asyncio.TimeoutError("error")',
'import asyncio\n'
'raise TimeoutError("error")',
(3, 11),
id='rewriting asyncio.TimeoutError',
),
pytest.param(
'from concurrent import futures\n'
'raise futures.TimeoutError("error")',
'from concurrent import futures\n'
'raise TimeoutError("error")',
(3, 11),
id='rewriting futures.TimeoutError',
),
),
)
def test_fix_timeout_error_alias(s, expected, version):
assert _fix_plugins(s, settings=Settings(min_version=version)) == expected


@pytest.mark.parametrize(
('s', 'version'),
(
pytest.param(
'import socket\n'
'raise socket.timeout("error")',
(3, 9),
id='socket.timeout not Python 3.10+',
),
pytest.param(
'import foo\n'
'raise foo.timeout("error")',
(3, 10),
id='timeout not from socket as attr',
),
pytest.param(
'from foo import timeout\n'
'raise timeout("error")',
(3, 10),
id='timeout not from socket',
),
pytest.param(
'import asyncio\n'
'raise asyncio.TimeoutError("error")',
(3, 10),
id='asyncio.TimeoutError not Python 3.11+',
),
pytest.param(
'from concurrent import futures\n'
'raise futures.TimeoutError("error")',
(3, 10),
id='concurrent.futures.TimeoutError not Python 3.11+',
),
),
)
def test_fix_timeout_error_alias_noop(s, version):
assert _fix_plugins(s, settings=Settings(min_version=version)) == s


@pytest.mark.parametrize(
('s', 'expected'),
(
pytest.param(
'import asyncio\n'
'try:\n'
' pass\n'
'except asyncio.TimeoutError as e:\n'
' pass',
'import asyncio\n'
'try:\n'
' pass\n'
'except TimeoutError as e:\n'
' pass',
id='rewriting asyncio.TimeoutError in try/except',
),
),
)
def test_alias_in_try_except(s, expected):
assert _fix_plugins(s, settings=Settings(min_version=(3, 11))) == expected

0 comments on commit 8a8e164

Please sign in to comment.