diff --git a/setup.cfg b/setup.cfg index 8ec8fbbf6..46f4e3b99 100644 --- a/setup.cfg +++ b/setup.cfg @@ -82,7 +82,7 @@ plugins = [coverage:report] precision = 2 -fail_under = 97.79 +fail_under = 97.82 show_missing = true skip_covered = true exclude_lines = diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 7b0619cae..a5fe93d5b 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -10,6 +10,7 @@ try: import websockets + import websockets.client import websockets.exceptions from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory @@ -64,7 +65,6 @@ def app(scope): "connection": "upgrade", "sec-webSocket-version": "11", }, - timeout=5, ) if response.status_code == 426: # response.text == "" @@ -517,6 +517,40 @@ async def websocket_session(url): await websocket_session("ws://127.0.0.1:8000") +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_client_connection_lost(ws_protocol_cls, http_protocol_cls): + got_disconnect_event = False + + async def app(scope, receive, send): + nonlocal got_disconnect_event + while True: + message = await receive() + if message["type"] == "websocket.connect": + print("accepted") + await send({"type": "websocket.accept"}) + elif message["type"] == "websocket.disconnect": + break + + got_disconnect_event = True + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + ws_ping_interval=0.0, + ) + async with run_server(config): + async with websockets.client.connect("ws://127.0.0.1:8000") as websocket: + websocket.transport.close() + await asyncio.sleep(0.1) + got_disconnect_event_before_shutdown = got_disconnect_event + + assert got_disconnect_event_before_shutdown is True + + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 627e32ef8..6e4f505f1 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -70,8 +70,7 @@ def connection_made(self, transport): self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc): - if exc is not None: - self.queue.put_nowait({"type": "websocket.disconnect"}) + self.queue.put_nowait({"type": "websocket.disconnect"}) self.connections.remove(self) if self.logger.level <= TRACE_LOG_LEVEL: