Skip to content

Commit

Permalink
Fix handling of user-defined types for sqlalchemy (#291)
Browse files Browse the repository at this point in the history
* fix handling of user-defined types for sqlalchemy #290

* rename test class
  • Loading branch information
vlanse authored and jettify committed May 20, 2018
1 parent 5b20a8a commit a923621
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 7 deletions.
6 changes: 6 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changes
-------

0.0.15 (2018-05-13)
^^^^^^^^^^^^^^^^^^^

* Fixed handling of user-defined types for sqlalchemy #290


0.0.14 (2018-04-22)
^^^^^^^^^^^^^^^^^^^

Expand Down
2 changes: 1 addition & 1 deletion aiomysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from .cursors import Cursor, SSCursor, DictCursor, SSDictCursor
from .pool import create_pool, Pool

__version__ = '0.0.14'
__version__ = '0.0.15'

__all__ = [

Expand Down
9 changes: 8 additions & 1 deletion aiomysql/sa/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ async def _execute(self, query, *multiparams, **params):
elif dp:
dp = dp[0]

result_map = None

if isinstance(query, str):
await cursor.execute(query, dp or None)
elif isinstance(query, ClauseElement):
Expand All @@ -97,18 +99,23 @@ async def _execute(self, query, *multiparams, **params):
processed_parameters.append(params)
post_processed_params = self._dialect.execute_sequence_format(
processed_parameters)
result_map = compiled._result_columns

else:
if dp:
raise exc.ArgumentError("Don't mix sqlalchemy DDL clause "
"and execution with parameters")
post_processed_params = [compiled.construct_params()]
result_map = None
await cursor.execute(str(compiled), post_processed_params[0])
else:
raise exc.ArgumentError("sql statement should be str or "
"SQLAlchemy data "
"selection/modification clause")

ret = await create_result_proxy(self, cursor, self._dialect)
ret = await create_result_proxy(
self, cursor, self._dialect, result_map
)
self._weak_results.add(ret)
return ret

Expand Down
22 changes: 17 additions & 5 deletions aiomysql/sa/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from . import exc


async def create_result_proxy(connection, cursor, dialect):
result_proxy = ResultProxy(connection, cursor, dialect)
async def create_result_proxy(connection, cursor, dialect, result_map):
result_proxy = ResultProxy(connection, cursor, dialect, result_map)
await result_proxy._prepare()
return result_proxy

Expand Down Expand Up @@ -95,6 +95,12 @@ class ResultMetaData:
def __init__(self, result_proxy, metadata):
self._processors = processors = []

result_map = {}

if result_proxy._result_map:
result_map = {elem[0]: elem[3] for elem in
result_proxy._result_map}

# We do not strictly need to store the processor in the key mapping,
# though it is faster in the Python version (probably because of the
# saved attribute lookup self._processors)
Expand Down Expand Up @@ -124,8 +130,13 @@ def __init__(self, result_proxy, metadata):
# if dialect.requires_name_normalize:
# colname = dialect.normalize_name(colname)

name, obj, type_ = \
colname, None, typemap.get(coltype, sqltypes.NULLTYPE)
name, obj, type_ = (
colname,
None,
result_map.get(
colname,
typemap.get(coltype, sqltypes.NULLTYPE))
)

processor = type_._cached_result_processor(dialect, coltype)

Expand Down Expand Up @@ -223,13 +234,14 @@ class ResultProxy:
the originating SQL statement that produced this result set.
"""

def __init__(self, connection, cursor, dialect):
def __init__(self, connection, cursor, dialect, result_map):
self._dialect = dialect
self._closed = False
self._cursor = cursor
self._connection = connection
self._rowcount = cursor.rowcount
self._lastrowid = cursor.lastrowid
self._result_map = result_map

async def _prepare(self):
loop = self._connection.connection.loop
Expand Down
94 changes: 94 additions & 0 deletions tests/sa/test_sa_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import asyncio
from aiomysql import connect, sa
from enum import IntEnum

import os
import unittest
from unittest import mock

from sqlalchemy import MetaData, Table, Column, Integer, TypeDecorator


class UserDefinedEnum(IntEnum):
Value1 = 111
Value2 = 222


class IntEnumField(TypeDecorator):
impl = Integer

def __init__(self, enum_class, *arg, **kw):
TypeDecorator.__init__(self, *arg, **kw)
self.enum_class = enum_class

def process_bind_param(self, value, dialect):
""" From python to DB """
if value is None:
return None
elif not isinstance(value, self.enum_class):
return self.enum_class(value).value
else:
return value.value

def process_result_value(self, value, dialect):
""" From DB to Python """
if value is None:
return None

return self.enum_class(value)


meta = MetaData()
tbl = Table('sa_test_type_tbl', meta,
Column('id', Integer, nullable=False,
primary_key=True),
Column('val', IntEnumField(enum_class=UserDefinedEnum)))


class TestSATypes(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.host = os.environ.get('MYSQL_HOST', 'localhost')
self.port = int(os.environ.get('MYSQL_PORT', 3306))
self.user = os.environ.get('MYSQL_USER', 'root')
self.db = os.environ.get('MYSQL_DB', 'test_pymysql')
self.password = os.environ.get('MYSQL_PASSWORD', '')

def tearDown(self):
self.loop.close()

async def connect(self, **kwargs):
conn = await connect(db=self.db,
user=self.user,
password=self.password,
host=self.host,
loop=self.loop,
port=self.port,
**kwargs)
await conn.autocommit(True)
cur = await conn.cursor()
await cur.execute("DROP TABLE IF EXISTS sa_test_type_tbl")
await cur.execute("CREATE TABLE sa_test_type_tbl "
"(id serial, val bigint)")
await cur._connection.commit()
engine = mock.Mock()
engine.dialect = sa.engine._dialect
return sa.SAConnection(conn, engine)

def test_values(self):
async def go():
conn = await self.connect()

await conn.execute(tbl.insert().values(
val=UserDefinedEnum.Value1)
)
result = await conn.execute(tbl.select().where(
tbl.c.val == UserDefinedEnum.Value1)
)
data = await result.fetchone()
self.assertEqual(
data['val'], UserDefinedEnum.Value1
)

self.loop.run_until_complete(go())

0 comments on commit a923621

Please sign in to comment.