Skip to content

Commit

Permalink
Fix regression with connection upgrade (aio-libs#7879)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored and Xiang Li committed Dec 4, 2023
1 parent 27e9422 commit 24434da
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 116 deletions.
1 change: 1 addition & 0 deletions CHANGES/7879.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a regression where a connection may get closed during an upgrade. -- by :user:`Dreamsorcerer`
158 changes: 42 additions & 116 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import codecs
import contextlib
import dataclasses
import functools
import io
import re
Expand All @@ -26,7 +27,6 @@
cast,
)

import attr
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from yarl import URL

Expand All @@ -42,14 +42,17 @@
)
from .compression_utils import HAS_BROTLI
from .formdata import FormData
from .hdrs import CONTENT_TYPE
from .helpers import (
BaseTimerContext,
BasicAuth,
HeadersMixin,
TimerNoop,
basicauth_from_netrc,
is_expected_content_type,
netrc_from_env,
noop,
parse_mimetype,
reify,
set_result,
)
Expand Down Expand Up @@ -88,30 +91,25 @@


_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json")


def _gen_default_accept_encoding() -> str:
    """Build the default accept-encoding string.

    ``br`` is appended to ``gzip, deflate`` only when ``HAS_BROTLI`` is true.
    """
    if HAS_BROTLI:
        return "gzip, deflate, br"
    return "gzip, deflate"


@attr.s(auto_attribs=True, frozen=True, slots=True)
@dataclasses.dataclass(frozen=True)
class ContentDisposition:
type: Optional[str]
parameters: "MappingProxyType[str, str]"
filename: Optional[str]


@attr.s(auto_attribs=True, frozen=True, slots=True)
@dataclasses.dataclass(frozen=True)
class RequestInfo:
url: URL
method: str
headers: "CIMultiDictProxy[str]"
real_url: URL = attr.ib()

@real_url.default
def real_url_default(self) -> URL:
return self.url
real_url: URL


class Fingerprint:
Expand Down Expand Up @@ -154,60 +152,7 @@ def check(self, transport: asyncio.Transport) -> None:
SSL_ALLOWED_TYPES = type(None)


def _merge_ssl_params(
    ssl: Union["SSLContext", Literal[False], Fingerprint, None],
    verify_ssl: Optional[bool],
    ssl_context: Optional["SSLContext"],
    fingerprint: Optional[bytes],
) -> Union["SSLContext", Literal[False], Fingerprint, None]:
    """Fold the deprecated TLS arguments into a single ``ssl`` value.

    Each of ``verify_ssl`` (when false), ``ssl_context`` and ``fingerprint``
    emits a ``DeprecationWarning`` and is then translated to its ``ssl``
    equivalent: ``False``, the context itself, or ``Fingerprint(fingerprint)``.

    Raises:
        ValueError: a deprecated argument is supplied while an ``ssl`` value
            is already set (either passed in or derived from an earlier
            deprecated argument).
        TypeError: the resulting value is not an instance of
            ``SSL_ALLOWED_TYPES``.
    """
    conflict_message = (
        "verify_ssl, ssl_context, fingerprint and ssl "
        "parameters are mutually exclusive"
    )
    merged = ssl

    # Only an explicit falsy verify_ssl counts; None means "not supplied".
    if not (verify_ssl is None or verify_ssl):
        warnings.warn(
            "verify_ssl is deprecated, use ssl=False instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if merged is not None:
            raise ValueError(conflict_message)
        merged = False

    if ssl_context is not None:
        warnings.warn(
            "ssl_context is deprecated, use ssl=context instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if merged is not None:
            raise ValueError(conflict_message)
        merged = ssl_context

    if fingerprint is not None:
        warnings.warn(
            "fingerprint is deprecated, use ssl=Fingerprint(fingerprint) instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if merged is not None:
            raise ValueError(conflict_message)
        merged = Fingerprint(fingerprint)

    if isinstance(merged, SSL_ALLOWED_TYPES):
        return merged
    raise TypeError(
        "ssl should be SSLContext, bool, Fingerprint or None, "
        f"got {merged!r} instead."
    )


@attr.s(auto_attribs=True, slots=True, frozen=True)
@dataclasses.dataclass(frozen=True)
class ConnectionKey:
# the key should contain information about the proxy / TLS in use
# to prevent reusing the wrong connections from a pool
Expand All @@ -220,14 +165,6 @@ class ConnectionKey:
proxy_headers_hash: Optional[int] # hash(CIMultiDict)


def _is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Tell whether a response content type satisfies the expected one.

    When ``application/json`` is expected, the response type is matched
    against ``json_re`` so that suffixed forms such as
    ``application/vnd.api+json`` are accepted. Any other expected type is
    a plain substring test against the response type.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    matched = json_re.match(response_content_type)
    return matched is not None


class ClientRequest:
GET_METHODS = {
hdrs.METH_GET,
Expand Down Expand Up @@ -270,7 +207,7 @@ def __init__(
compress: Optional[str] = None,
chunked: Optional[bool] = None,
expect100: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
loop: asyncio.AbstractEventLoop,
response_class: Optional[Type["ClientResponse"]] = None,
proxy: Optional[URL] = None,
proxy_auth: Optional[BasicAuth] = None,
Expand All @@ -282,16 +219,12 @@ def __init__(
trust_env: bool = False,
server_hostname: Optional[str] = None,
):
if loop is None:
loop = asyncio.get_event_loop()

match = _CONTAINS_CONTROL_CHAR_RE.search(method)
if match:
raise ValueError(
f"Method cannot contain non-token characters {method!r} "
"(found at least {match.group()!r})"
f"(found at least {match.group()!r})"
)

assert isinstance(url, URL), url
assert isinstance(proxy, (URL, type(None))), proxy
# FIXME: session is None in tests only, need to fix tests
Expand Down Expand Up @@ -548,7 +481,12 @@ def update_body_from_data(self, body: Any) -> None:
try:
body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
except payload.LookupError:
body = FormData(body)()
boundary = None
if CONTENT_TYPE in self.headers:
boundary = parse_mimetype(self.headers[CONTENT_TYPE]).parameters.get(
"boundary"
)
body = FormData(body, boundary=boundary)()

self.body = body

Expand All @@ -564,7 +502,7 @@ def update_body_from_data(self, body: Any) -> None:

# copy payload headers
assert body.headers
for (key, value) in body.headers.items():
for key, value in body.headers.items():
if key in self.headers:
continue
if key in self.skip_auto_headers:
Expand Down Expand Up @@ -752,7 +690,6 @@ async def _on_headers_request_sent(


class ClientResponse(HeadersMixin):

# Some of these attributes are None when created,
# but will be set by the start() method.
# As the end user will likely never see the None values, we cheat the types below.
Expand All @@ -761,7 +698,7 @@ class ClientResponse(HeadersMixin):
status: int = None # type: ignore[assignment] # Status-Code
reason: Optional[str] = None # Reason-Phrase

content: StreamReader = None # type: ignore[assignment] # Payload stream
content: StreamReader = None # type: ignore[assignment] # Payload stream
_headers: CIMultiDictProxy[str] = None # type: ignore[assignment]
_raw_headers: RawHeaders = None # type: ignore[assignment]

Expand All @@ -780,20 +717,21 @@ def __init__(
*,
writer: "asyncio.Task[None]",
continue100: Optional["asyncio.Future[bool]"],
timer: BaseTimerContext,
timer: Optional[BaseTimerContext],
request_info: RequestInfo,
traces: List["Trace"],
loop: asyncio.AbstractEventLoop,
session: "ClientSession",
) -> None:
assert isinstance(url, URL)
super().__init__()

self.method = method
self.cookies = SimpleCookie()

self._real_url = url
self._url = url.with_fragment(None)
self._body: Any = None
self._body: Optional[bytes] = None
self._writer: Optional[asyncio.Task[None]] = writer
self._continue = continue100 # None by default
self._closed = True
Expand Down Expand Up @@ -836,11 +774,6 @@ def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
def url(self) -> URL:
return self._url

@reify
def url_obj(self) -> URL:
warnings.warn("Deprecated, use .url #1654", DeprecationWarning, stacklevel=2)
return self._url

@reify
def real_url(self) -> URL:
return self._real_url
Expand Down Expand Up @@ -881,8 +814,9 @@ def __del__(self, _warnings: Any = warnings) -> None:
self._cleanup_writer()

if self._loop.get_debug():
kwargs = {"source": self}
_warnings.warn(f"Unclosed response {self!r}", ResourceWarning, **kwargs)
_warnings.warn(
f"Unclosed response {self!r}", ResourceWarning, source=self
)
context = {"client_response": self, "message": "Unclosed response"}
if self._source_traceback:
context["source_traceback"] = self._source_traceback
Expand Down Expand Up @@ -912,7 +846,7 @@ def connection(self) -> Optional["Connection"]:

@reify
def history(self) -> Tuple["ClientResponse", ...]:
"""A sequence of of responses, if redirects occurred."""
"""A sequence of responses, if redirects occurred."""
return self._history

@reify
Expand Down Expand Up @@ -1006,19 +940,14 @@ def _response_eof(self) -> None:
if self._closed:
return

if self._connection is not None:
# websocket, protocol could be None because
# connection could be detached
if (
self._connection.protocol is not None
and self._connection.protocol.upgraded
):
return

self._release_connection()
# protocol could be None because connection could be detached
protocol = self._connection and self._connection.protocol
if protocol is not None and protocol.upgraded:
return

self._closed = True
self._cleanup_writer()
self._release_connection()

@property
def closed(self) -> bool:
Expand All @@ -1029,7 +958,7 @@ def close(self) -> None:
self._notify_content()

self._closed = True
if self._loop is None or self._loop.is_closed():
if self._loop.is_closed():
return

self._cleanup_writer()
Expand Down Expand Up @@ -1089,7 +1018,8 @@ def _cleanup_writer(self) -> None:

def _notify_content(self) -> None:
content = self.content
if content and content.exception() is None:
# content can be None here, but the types are cheated elsewhere.
if content and content.exception() is None: # type: ignore[truthy-bool]
content.set_exception(ClientConnectionError("Connection closed"))
self._released = True

Expand All @@ -1113,8 +1043,10 @@ async def read(self) -> bytes:
elif self._released: # Response explicitly released
raise ClientConnectionError("Connection closed")

await self._wait_released() # Underlying connection released
return self._body # type: ignore[no-any-return]
protocol = self._connection and self._connection.protocol
if protocol is None or not protocol.upgraded:
await self._wait_released() # Underlying connection released
return self._body

def get_encoding(self) -> str:
ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
Expand Down Expand Up @@ -1147,9 +1079,7 @@ async def text(self, encoding: Optional[str] = None, errors: str = "strict") ->
if encoding is None:
encoding = self.get_encoding()

return self._body.decode( # type: ignore[no-any-return,union-attr]
encoding, errors=errors
)
return self._body.decode(encoding, errors=errors) # type: ignore[union-attr]

async def json(
self,
Expand All @@ -1163,25 +1093,21 @@ async def json(
await self.read()

if content_type:
ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
if not _is_expected_content_type(ctype, content_type):
if not is_expected_content_type(self.content_type, content_type):
raise ContentTypeError(
self.request_info,
self.history,
message=(
"Attempt to decode JSON with " "unexpected mimetype: %s" % ctype
"Attempt to decode JSON with "
"unexpected mimetype: %s" % self.content_type
),
headers=self.headers,
)

stripped = self._body.strip() # type: ignore[union-attr]
if not stripped:
return None

if encoding is None:
encoding = self.get_encoding()

return loads(stripped.decode(encoding))
return loads(self._body.decode(encoding)) # type: ignore[union-attr]

async def __aenter__(self) -> "ClientResponse":
return self
Expand Down
19 changes: 19 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,25 @@ async def handler(request):
assert 1 == len(client._session.connector._conns)


async def test_upgrade_connection_not_released_after_read(aiohttp_client: Any) -> None:
    """Reading a 101 upgrade response must leave the connection attached and open."""

    async def handler(request: web.Request) -> web.Response:
        payload = await request.read()
        assert payload == b""
        upgrade_headers = {"Connection": "Upgrade", "Upgrade": "tcp"}
        return web.Response(status=101, headers=upgrade_headers)

    application = web.Application()
    application.router.add_route("GET", "/", handler)
    client = await aiohttp_client(application)

    response = await client.get("/")
    await response.read()

    # After read() the response must still own its connection and stay open.
    assert response.connection is not None
    assert not response.closed


async def test_keepalive_server_force_close_connection(aiohttp_client: Any) -> None:
async def handler(request):
body = await request.read()
Expand Down

0 comments on commit 24434da

Please sign in to comment.