Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed connection cancellation in process of executing a query #79

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 25 additions & 4 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ def __init__(self, host="localhost", user=None, password="",
# asyncio StreamReader, StreamWriter
self._reader = None
self._writer = None
# If connection was closed for specific reason, we should show that to
# user
self._close_reason = None

@property
def host(self):
Expand Down Expand Up @@ -359,6 +362,7 @@ def cursor(self, cursor=None):
:returns: instance of cursor, by default :class:`Cursor`
:raises TypeError: cursor_class is not a subclass of Cursor.
"""
self._ensure_alive()
if cursor is not None and not issubclass(cursor, Cursor):
raise TypeError('Custom cursor must be subclass of Cursor')

Expand Down Expand Up @@ -514,6 +518,9 @@ def _read_packet(self, packet_type=MysqlPacket):
buff += recv_data
if bytes_to_read < MAX_PACKET_LEN:
break
except asyncio.CancelledError:
self._close_on_cancel()
raise
except (OSError, EOFError) as exc:
msg = "MySQL server has gone away (%s)"
raise OperationalError(2006, msg % (exc,)) from exc
Expand Down Expand Up @@ -563,8 +570,7 @@ def __aexit__(self, exc_type, exc_val, exc_tb):

@asyncio.coroutine
def _execute_command(self, command, sql):
if not self._writer:
raise InterfaceError("(0, 'Not connected')")
self._ensure_alive()

# If the last query was unbuffered, make sure it finishes before
# sending new commands
Expand Down Expand Up @@ -699,6 +705,19 @@ def get_transaction_status(self):
def get_server_info(self):
return self.server_version

# Just to always have consistent errors 2 helpers

def _close_on_cancel(self):
self.close()
self._close_reason = "Cancelled during execution"

def _ensure_alive(self):
if not self._writer:
if self._close_reason is None:
raise InterfaceError("(0, 'Not connected')")
else:
raise InterfaceError(self._close_reason)

if PY_341: # pragma: no branch
def __del__(self):
if self._writer:
Expand Down Expand Up @@ -955,8 +974,7 @@ def freader(chunk_size):
@asyncio.coroutine
def send_data(self):
"""Send data packets from the local file to the server"""
if not self.connection._writer:
raise InterfaceError("(0, '')")
self.connection._ensure_alive()
conn = self.connection

try:
Expand All @@ -968,6 +986,9 @@ def send_data(self):
if not chunk:
break
conn.write_packet(chunk)
except asyncio.CancelledError:
self.connection._close_on_cancel()
raise
finally:
# send the empty packet to signify we are done sending data
conn.write_packet(b"")
20 changes: 19 additions & 1 deletion tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from tests import base
from tests._testutils import run_until_complete

from aiomysql import ProgrammingError, Cursor
from aiomysql import ProgrammingError, Cursor, InterfaceError


class TestCursor(base.AIOPyMySQLTestCase):
Expand Down Expand Up @@ -286,3 +286,21 @@ def test_morgify(self):
"INSERT INTO tbl VALUES(2, 'b')",
"INSERT INTO tbl VALUES(3, 'c')"]
self.assertEqual(results, expected)

@run_until_complete
def test_execute_cancel(self):
conn = self.connections[0]
cur = yield from conn.cursor()
# Cancel a cursor in the middle of execution, before it could
# read even the first packet (SLEEP assures the timings)
task = self.loop.create_task(cur.execute(
"SELECT 1 as id, SLEEP(0.1) as xxx"))
yield from asyncio.sleep(0.05, loop=self.loop)
task.cancel()
try:
yield from task
except asyncio.CancelledError:
pass

with self.assertRaises(InterfaceError):
yield from conn.cursor()
31 changes: 31 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,3 +603,34 @@ def go():
yield from self._set_global_conn_timeout(28800)

self.loop.run_until_complete(go())

def test_cancelled_connection(self):
@asyncio.coroutine
def go():
pool = yield from self.create_pool(minsize=0, maxsize=1)

try:
with (yield from pool) as conn:
curs = yield from conn.cursor()
# Cancel a cursor in the middle of execution, before it
# could read even the first packet (SLEEP assures the
# timings)
task = self.loop.create_task(curs.execute(
"SELECT 1 as id, SLEEP(0.1) as xxx"))
yield from asyncio.sleep(0.05, loop=self.loop)
task.cancel()
yield from task
except asyncio.CancelledError:
pass

with (yield from pool) as conn:
cur2 = yield from conn.cursor()
res = yield from cur2.execute("SELECT 2 as value, 0 as xxx")
names = [x[0] for x in cur2.description]
# If we receive ["id", "xxx"] - we corrupted the connection
self.assertEqual(names, ["value", "xxx"])
res = yield from cur2.fetchall()
# If we receive [(1, 0)] - we retrieved old cursor's values
self.assertEqual(list(res), [(2, 0)])

self.loop.run_until_complete(go())
44 changes: 43 additions & 1 deletion tests/test_sscursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tests import base
from tests._testutils import run_until_complete

from aiomysql import ProgrammingError
from aiomysql import ProgrammingError, InterfaceError


class TestSSCursor(base.AIOPyMySQLTestCase):
Expand Down Expand Up @@ -142,3 +142,45 @@ def test_sscursor_scroll_errors(self):
yield from cursor.scroll(1, mode='absolute')
with self.assertRaises(ProgrammingError):
yield from cursor.scroll(2, mode='not_valid_mode')

@run_until_complete
def test_sscursor_cancel(self):
conn = self.connections[0]
cur = yield from conn.cursor(SSCursor)
# Prepare ALOT of data

yield from cur.execute('DROP TABLE IF EXISTS long_seq;')
yield from cur.execute(
""" CREATE TABLE long_seq (
id int(11)
)
""")

ids = [(x) for x in range(100000)]
yield from cur.executemany('INSERT INTO long_seq VALUES (%s)', ids)

# Will return several results. All we need at this point
big_str = "x" * 10000
yield from cur.execute(
"""SELECT '{}' as id FROM long_seq;
""".format(big_str))
first = yield from cur.fetchone()
self.assertEqual(first, (big_str,))

@asyncio.coroutine
def read_cursor():
while True:
res = yield from cur.fetchone()
if res is None:
break
task = self.loop.create_task(read_cursor())
yield from asyncio.sleep(0, loop=self.loop)
assert not task.done(), "Test failed to produce needed condition."
task.cancel()
try:
yield from task
except asyncio.CancelledError:
pass

with self.assertRaises(InterfaceError):
yield from conn.cursor(SSCursor)