Skip to content

Commit

Permalink
added support for sqlalchemy default parameters aio-libs#455
Browse files Browse the repository at this point in the history
  • Loading branch information
Ганжин Михаил committed Dec 11, 2019
1 parent b16e5bd commit 160a4c5
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 0 deletions.
18 changes: 18 additions & 0 deletions aiomysql/sa/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,29 @@

try:
from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql
from sqlalchemy.dialects.mysql.mysqldb import MySQLCompiler_mysqldb
except ImportError: # pragma: no cover
raise ImportError('aiomysql.sa requires sqlalchemy')


class MySQLCompiler_pymysql(MySQLCompiler_mysqldb):
def construct_params(self, params=None, _group_number=None, _check=True):
pd = super().construct_params(params, _group_number, _check)

for column in self.prefetch:
pd[column.key] = self._exec_default(column.default)

return pd

def _exec_default(self, default):
if default.is_callable:
return default.arg(self.dialect)
else:
return default.arg


_dialect = MySQLDialect_pymysql(paramstyle='pyformat')
_dialect.statement_compiler = MySQLCompiler_pymysql
_dialect.default_paramstyle = 'pyformat'


Expand Down
117 changes: 117 additions & 0 deletions tests/sa/test_sa_default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import asyncio
import datetime
import os
import unittest
from unittest import mock

import sqlalchemy as sa

import aiomysql.sa
from aiomysql import connect

meta = sa.MetaData()
table = sa.Table('sa_tbl', meta,
sa.Column('id', sa.Integer, nullable=False, primary_key=True),
sa.Column('string_length', sa.Integer,
default=sa.func.length('qwerty')),
sa.Column('number', sa.Integer, default=100, nullable=False),
sa.Column('description', sa.String(255), nullable=False,
default='default test'),
sa.Column('created_at', sa.DateTime,
default=datetime.datetime.now),
sa.Column('enabled', sa.Boolean, default=True))


class TestSAConnection(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.host = os.environ.get('MYSQL_HOST', 'localhost')
self.port = int(os.environ.get('MYSQL_PORT', 3306))
self.user = os.environ.get('MYSQL_USER', 'root')
self.db = os.environ.get('MYSQL_DB', 'test_pymysql')
self.password = os.environ.get('MYSQL_PASSWORD', '')

def tearDown(self):
self.loop.close()

async def connect(self, **kwargs):
conn = await connect(db=self.db,
user=self.user,
password=self.password,
host=self.host,
loop=self.loop,
port=self.port,
**kwargs)
await conn.autocommit(True)
cur = await conn.cursor()
await cur.execute("DROP TABLE IF EXISTS sa_tbl")
await cur.execute("CREATE TABLE sa_tbl "
"(id integer, string_length integer, number integer,"
" description VARCHAR(255), created_at DATETIME(6), "
"enabled TINYINT)")

await cur._connection.commit()
# await cur.close()
engine = mock.Mock()
engine.dialect = aiomysql.sa.engine._dialect
return aiomysql.sa.SAConnection(conn, engine)

def test_default_fields(self):
async def go():
conn = await self.connect()
await conn.execute(table.insert().values())

res = await conn.execute(table.select())
row = await res.fetchone()
self.assertEqual(row.string_length, 6)
self.assertEqual(row.number, 100)
self.assertEqual(row.description, 'default test')
self.assertEqual(row.enabled, True)
self.assertEqual(type(row.created_at), datetime.datetime)

self.loop.run_until_complete(go())

def test_default_fields_isnull(self):
async def go():
conn = await self.connect()
created_at = None
enabled = False
await conn.execute(table.insert().values(
enabled=enabled,
created_at=created_at,
))

res = await conn.execute(table.select())
row = await res.fetchone()
self.assertEqual(row.number, 100)
self.assertEqual(row.string_length, 6)
self.assertEqual(row.description, 'default test')
self.assertEqual(row.enabled, enabled)
self.assertEqual(row.created_at, created_at)

self.loop.run_until_complete(go())

def test_default_fields_edit(self):
async def go():
conn = await self.connect()
created_at = datetime.datetime.now()
description = 'new descr'
enabled = False
number = 111
await conn.execute(table.insert().values(
description=description,
enabled=enabled,
created_at=created_at,
number=number,
))

res = await conn.execute(table.select())
row = await res.fetchone()
self.assertEqual(row.number, number)
self.assertEqual(row.string_length, 6)
self.assertEqual(row.description, description)
self.assertEqual(row.enabled, enabled)
self.assertEqual(row.created_at, created_at)

self.loop.run_until_complete(go())

0 comments on commit 160a4c5

Please sign in to comment.