Skip to content

Commit

Permalink
Use typing.Literal available in Python 3.8+ (#3365)
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile committed Mar 17, 2024
1 parent cc892b0 commit da892d6
Show file tree
Hide file tree
Showing 12 changed files with 26 additions and 44 deletions.
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
("py:class", "_HttplibHTTPResponse"),
("py:class", "_HttplibHTTPMessage"),
("py:class", "TracebackType"),
("py:class", "Literal"),
("py:class", "email.errors.MessageDefect"),
("py:class", "MessageDefect"),
("py:class", "http.client.HTTPMessage"),
Expand Down
8 changes: 4 additions & 4 deletions src/urllib3/_base_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class ProxyConfig(typing.NamedTuple):
ssl_context: ssl.SSLContext | None
use_forwarding_for_https: bool
assert_hostname: None | str | Literal[False]
assert_hostname: None | str | typing.Literal[False]
assert_fingerprint: str | None


Expand All @@ -28,7 +28,7 @@ class _ResponseOptions(typing.NamedTuple):

if typing.TYPE_CHECKING:
import ssl
from typing import Literal, Protocol
from typing import Protocol

from .response import BaseHTTPResponse

Expand Down Expand Up @@ -124,7 +124,7 @@ class BaseHTTPSConnection(BaseHTTPConnection, Protocol):

# Certificate verification methods
cert_reqs: int | str | None
assert_hostname: None | str | Literal[False]
assert_hostname: None | str | typing.Literal[False]
assert_fingerprint: str | None
ssl_context: ssl.SSLContext | None

Expand Down Expand Up @@ -155,7 +155,7 @@ def __init__(
proxy: Url | None = None,
proxy_config: ProxyConfig | None = None,
cert_reqs: int | str | None = None,
assert_hostname: None | str | Literal[False] = None,
assert_hostname: None | str | typing.Literal[False] = None,
assert_fingerprint: str | None = None,
server_hostname: str | None = None,
ssl_context: ssl.SSLContext | None = None,
Expand Down
8 changes: 3 additions & 5 deletions src/urllib3/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from socket import timeout as SocketTimeout

if typing.TYPE_CHECKING:
from typing import Literal

from .response import HTTPResponse
from .util.ssl_ import _TYPE_PEER_CERT_RET_DICT
from .util.ssltransport import SSLTransport
Expand Down Expand Up @@ -523,7 +521,7 @@ def __init__(
proxy: Url | None = None,
proxy_config: ProxyConfig | None = None,
cert_reqs: int | str | None = None,
assert_hostname: None | str | Literal[False] = None,
assert_hostname: None | str | typing.Literal[False] = None,
assert_fingerprint: str | None = None,
server_hostname: str | None = None,
ssl_context: ssl.SSLContext | None = None,
Expand Down Expand Up @@ -577,7 +575,7 @@ def set_cert(
cert_reqs: int | str | None = None,
key_password: str | None = None,
ca_certs: str | None = None,
assert_hostname: None | str | Literal[False] = None,
assert_hostname: None | str | typing.Literal[False] = None,
assert_fingerprint: str | None = None,
ca_cert_dir: str | None = None,
ca_cert_data: None | str | bytes = None,
Expand Down Expand Up @@ -742,7 +740,7 @@ def _ssl_wrap_socket_and_match_hostname(
ca_certs: str | None,
ca_cert_dir: str | None,
ca_cert_data: None | str | bytes,
assert_hostname: None | str | Literal[False],
assert_hostname: None | str | typing.Literal[False],
assert_fingerprint: str | None,
server_hostname: str | None,
ssl_context: ssl.SSLContext | None,
Expand Down
5 changes: 2 additions & 3 deletions src/urllib3/connectionpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@

if typing.TYPE_CHECKING:
import ssl
from typing import Literal

from typing_extensions import Self

Expand Down Expand Up @@ -103,7 +102,7 @@ def __exit__(
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
) -> typing.Literal[False]:
self.close()
# Return False to re-raise any potential exceptions
return False
Expand Down Expand Up @@ -1002,7 +1001,7 @@ def __init__(
ssl_version: int | str | None = None,
ssl_minimum_version: ssl.TLSVersion | None = None,
ssl_maximum_version: ssl.TLSVersion | None = None,
assert_hostname: str | Literal[False] | None = None,
assert_hostname: str | typing.Literal[False] | None = None,
assert_fingerprint: str | None = None,
ca_cert_dir: str | None = None,
**conn_kw: typing.Any,
Expand Down
4 changes: 1 addition & 3 deletions src/urllib3/contrib/socks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,8 @@
except ImportError:
ssl = None # type: ignore[assignment]

from typing import TypedDict


class _TYPE_SOCKS_OPTIONS(TypedDict):
class _TYPE_SOCKS_OPTIONS(typing.TypedDict):
socks_version: int
proxy_host: str | None
proxy_port: str | None
Expand Down
5 changes: 2 additions & 3 deletions src/urllib3/poolmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

if typing.TYPE_CHECKING:
import ssl
from typing import Literal

from typing_extensions import Self

Expand Down Expand Up @@ -222,7 +221,7 @@ def __exit__(
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
) -> typing.Literal[False]:
self.clear()
# Return False to re-raise any potential exceptions
return False
Expand Down Expand Up @@ -553,7 +552,7 @@ def __init__(
proxy_headers: typing.Mapping[str, str] | None = None,
proxy_ssl_context: ssl.SSLContext | None = None,
use_forwarding_for_https: bool = False,
proxy_assert_hostname: None | str | Literal[False] = None,
proxy_assert_hostname: None | str | typing.Literal[False] = None,
proxy_assert_fingerprint: str | None = None,
**connection_pool_kw: typing.Any,
) -> None:
Expand Down
4 changes: 1 addition & 3 deletions src/urllib3/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@
from .util.retry import Retry

if typing.TYPE_CHECKING:
from typing import Literal

from .connectionpool import HTTPConnectionPool

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -347,7 +345,7 @@ def __init__(
self._decoder: ContentDecoder | None = None
self.length_remaining: int | None

def get_redirect_location(self) -> str | None | Literal[False]:
def get_redirect_location(self) -> str | None | typing.Literal[False]:
"""
Should we redirect and where to?
Expand Down
4 changes: 2 additions & 2 deletions src/urllib3/util/ssl_.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _is_has_never_check_common_name_reliable(

if typing.TYPE_CHECKING:
from ssl import VerifyMode
from typing import Literal, TypedDict
from typing import TypedDict

from .ssltransport import SSLTransport as SSLTransportType

Expand Down Expand Up @@ -365,7 +365,7 @@ def ssl_wrap_socket(
ca_cert_dir: str | None = ...,
key_password: str | None = ...,
ca_cert_data: None | str | bytes = ...,
tls_in_tls: Literal[False] = ...,
tls_in_tls: typing.Literal[False] = ...,
) -> ssl.SSLSocket:
...

Expand Down
6 changes: 2 additions & 4 deletions src/urllib3/util/ssltransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from ..exceptions import ProxySchemeUnsupported

if typing.TYPE_CHECKING:
from typing import Literal

from typing_extensions import Self

from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
Expand Down Expand Up @@ -175,12 +173,12 @@ def close(self) -> None:

@typing.overload
def getpeercert(
self, binary_form: Literal[False] = ...
self, binary_form: typing.Literal[False] = ...
) -> _TYPE_PEER_CERT_RET_DICT | None:
...

@typing.overload
def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None:
...

def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
Expand Down
3 changes: 1 addition & 2 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@

if typing.TYPE_CHECKING:
import ssl
from typing import Literal


_RT = typing.TypeVar("_RT") # return type
Expand Down Expand Up @@ -266,7 +265,7 @@ def __exit__(
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> Literal[False]:
) -> typing.Literal[False]:
self.uninstall()
return False

Expand Down
13 changes: 5 additions & 8 deletions test/test_ssltransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
from urllib3.util import ssl_
from urllib3.util.ssltransport import SSLTransport

if typing.TYPE_CHECKING:
from typing import Literal

# consume_socket can iterate forever, we add timeouts to prevent halting.
PER_TEST_TIMEOUT = 60

Expand All @@ -34,12 +31,12 @@ def server_client_ssl_contexts() -> tuple[ssl.SSLContext, ssl.SSLContext]:


@typing.overload
def sample_request(binary: Literal[True] = ...) -> bytes:
def sample_request(binary: typing.Literal[True] = ...) -> bytes:
...


@typing.overload
def sample_request(binary: Literal[False]) -> str:
def sample_request(binary: typing.Literal[False]) -> str:
...


Expand All @@ -54,20 +51,20 @@ def sample_request(binary: bool = True) -> bytes | str:


def validate_request(
provided_request: bytearray, binary: Literal[False, True] = True
provided_request: bytearray, binary: typing.Literal[False, True] = True
) -> None:
assert provided_request is not None
expected_request = sample_request(binary)
assert provided_request == expected_request


@typing.overload
def sample_response(binary: Literal[True] = ...) -> bytes:
def sample_response(binary: typing.Literal[True] = ...) -> bytes:
...


@typing.overload
def sample_response(binary: Literal[False]) -> str:
def sample_response(binary: typing.Literal[False]) -> str:
...


Expand Down
9 changes: 3 additions & 6 deletions test/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@

from . import clear_warnings

if typing.TYPE_CHECKING:
from typing import Literal

# This number represents a time in seconds, it doesn't mean anything in
# isolation. Setting to a high-ish value to avoid conflicts with the smaller
# numbers used for timeouts
Expand Down Expand Up @@ -516,7 +513,7 @@ def test_netloc(self, url: str, expected_netloc: str | None) -> None:

@pytest.mark.parametrize("url, expected_url", url_vulnerabilities)
def test_url_vulnerabilities(
self, url: str, expected_url: Literal[False] | Url
self, url: str, expected_url: typing.Literal[False] | Url
) -> None:
if expected_url is False:
with pytest.raises(LocationParseError):
Expand Down Expand Up @@ -748,7 +745,7 @@ def test_timeout_elapsed(self, time_monotonic: MagicMock) -> None:
def test_is_fp_closed_object_supports_closed(self) -> None:
class ClosedFile:
@property
def closed(self) -> Literal[True]:
def closed(self) -> typing.Literal[True]:
return True

assert is_fp_closed(ClosedFile())
Expand All @@ -764,7 +761,7 @@ def fp(self) -> None:
def test_is_fp_closed_object_has_fp(self) -> None:
class FpFile:
@property
def fp(self) -> Literal[True]:
def fp(self) -> typing.Literal[True]:
return True

assert not is_fp_closed(FpFile())
Expand Down

0 comments on commit da892d6

Please sign in to comment.