Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Even more async support #57

Merged
merged 6 commits into from
Jan 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 11 additions & 2 deletions aiomysql/sa/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .result import create_result_proxy
from .transaction import (RootTransaction, Transaction,
NestedTransaction, TwoPhaseTransaction)
from ..utils import _TransactionContextManager, _SAConnectionContextManager


class SAConnection:
Expand All @@ -23,7 +24,6 @@ def __init__(self, connection, engine):
self._engine = engine
self._dialect = engine.dialect

@asyncio.coroutine
def execute(self, query, *multiparams, **params):
"""Executes a SQL query with optional parameters.

Expand Down Expand Up @@ -61,6 +61,11 @@ def execute(self, query, *multiparams, **params):
execution.

"""
coro = self._execute(query, *multiparams, **params)
return _SAConnectionContextManager(coro)

@asyncio.coroutine
def _execute(self, query, *multiparams, **params):
cursor = yield from self._connection.cursor()
dp = _distill_params(multiparams, params)
if len(dp) > 1:
Expand Down Expand Up @@ -124,7 +129,6 @@ def closed(self):
def connection(self):
return self._connection

@asyncio.coroutine
def begin(self):
"""Begin a transaction and return a transaction handle.

Expand Down Expand Up @@ -152,6 +156,11 @@ def begin(self):
.begin_twophase - use a two phase/XA transaction

"""
coro = self._begin()
return _TransactionContextManager(coro)

@asyncio.coroutine
def _begin(self):
if self._transaction is None:
self._transaction = RootTransaction(self)
yield from self._begin_impl()
Expand Down
4 changes: 3 additions & 1 deletion aiomysql/sa/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def _prepare(self):
cursor = self._cursor
if cursor.description is not None:
self._metadata = ResultMetaData(self, cursor.description)
callback = lambda wr: asyncio.Task(cursor.close(), loop=loop)

def callback(wr):
asyncio.Task(cursor.close(), loop=loop)
self._weak = weakref.ref(self, callback)
else:
self._metadata = None
Expand Down
21 changes: 12 additions & 9 deletions aiomysql/sa/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio

from . import exc
from ..utils import PY_35


class Transaction(object):
Expand Down Expand Up @@ -86,16 +87,18 @@ def commit(self):
def _do_commit(self):
pass

@asyncio.coroutine
def __aenter__(self):
return self
if PY_35: # pragma: no branch
@asyncio.coroutine
def __aenter__(self):
return self

@asyncio.coroutine
def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
yield from self.commit()
else:
yield from self.rollback()
@asyncio.coroutine
def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type:
yield from self.rollback()
else:
if self._is_active:
yield from self.commit()


class RootTransaction(Transaction):
Expand Down
23 changes: 23 additions & 0 deletions aiomysql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,29 @@ def __aexit__(self, exc_type, exc, tb):
self._obj = None


class _SAConnectionContextManager(_ContextManager):

if PY_35: # pragma: no branch
@asyncio.coroutine
def __aiter__(self):
result = yield from self._coro
return result


class _TransactionContextManager(_ContextManager):

if PY_35: # pragma: no branch

@asyncio.coroutine
def __aexit__(self, exc_type, exc, tb):
if exc_type:
yield from self._obj.rollback()
else:
if self._obj.is_active:
yield from self._obj.commit()
self._obj = None


class _PoolAcquireContextManager(_ContextManager):

__slots__ = ('_coro', '_conn', '_pool')
Expand Down
68 changes: 65 additions & 3 deletions tests/pep492/test_async_with.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def go():
async with conn:
await self._prepare(conn.connection)
ret = []
async for i in (await conn.execute(tbl.select())):
async for i in conn.execute(tbl.select()):
ret.append(i)
assert [(1, 'a'), (2, 'b'), (3, 'c')] == ret
assert conn.closed
Expand Down Expand Up @@ -249,7 +249,7 @@ async def go():
await self._prepare(conn.connection)

ret = []
async for i in (await conn.execute(tbl.select())):
async for i in conn.execute(tbl.select()):
ret.append(i)
assert [(1, 'a'), (2, 'b'), (3, 'c')] == ret

Expand All @@ -264,8 +264,70 @@ async def go():
await self._prepare(conn.connection)

ret = []
async for i in (await conn.execute(tbl.select())):
async for i in conn.execute(tbl.select()):
ret.append(i)
assert [(1, 'a'), (2, 'b'), (3, 'c')] == ret

self.loop.run_until_complete(go())

def test_transaction_context_manager(self):
async def go():
kw = self._conn_kw()
async with sa.create_engine(**kw) as engine:
async with engine.acquire() as conn:
await self._prepare(conn.connection)
async with conn.begin() as tr:
async with conn.execute(tbl.select()) as cursor:
ret = []
async for i in conn.execute(tbl.select()):
ret.append(i)
assert [(1, 'a'), (2, 'b'), (3, 'c')] == ret
assert cursor.closed
assert not tr.is_active

tr2 = await conn.begin()
async with tr2:
assert tr2.is_active
async with conn.execute('SELECT 1;') as cursor:
rec = await cursor.scalar()
assert rec == 1
cursor.close()
assert not tr2.is_active

assert conn.closed
self.loop.run_until_complete(go())

def test_transaction_context_manager_error(self):
async def go():
kw = self._conn_kw()
async with sa.create_engine(**kw) as engine:
async with engine.acquire() as conn:
with pytest.raises(RuntimeError) as ctx:
async with conn.begin() as tr:
assert tr.is_active
raise RuntimeError('boom')
assert str(ctx.value) == 'boom'
assert not tr.is_active
assert conn.closed
self.loop.run_until_complete(go())

def test_transaction_context_manager_commit_once(self):
async def go():
kw = self._conn_kw()
async with sa.create_engine(**kw) as engine:
async with engine.acquire() as conn:
async with conn.begin() as tr:
# check that in context manager we do not execute
# commit for second time. Two commits in row causes
# InvalidRequestError exception
await tr.commit()
assert not tr.is_active

tr2 = await conn.begin()
async with tr2:
assert tr2.is_active
# check for double commit one more time
await tr2.commit()
assert not tr2.is_active
assert conn.closed
self.loop.run_until_complete(go())