Skip to content

Commit

Permalink
fix unit socket implementation, most tests should be fine now (#696)
Browse files Browse the repository at this point in the history
configure custom socket path for mysql container, working around implicitly created volume folders being owned by root
we should probably just not use service containers for this to avoid having to do this patching
  • Loading branch information
Nothing4You committed Jan 28, 2022
1 parent 8fe7e53 commit ee5c48e
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 38 deletions.
17 changes: 16 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ jobs:
image: "${{ join(matrix.db, ':') }}"
ports:
- 3306:3306
volumes:
- "/tmp/run-${{ join(matrix.db, '-') }}/:/socket-mount/"
options: '--name=mysqld'
env:
MYSQL_ROOT_PASSWORD: rootpw
Expand Down Expand Up @@ -104,6 +106,19 @@ jobs:
docker container stop mysqld
docker container cp "${{ github.workspace }}/tests/ssl_resources/ssl" mysqld:/etc/mysql/ssl
docker container cp "${{ github.workspace }}/tests/ssl_resources/tls.cnf" mysqld:/etc/mysql/conf.d/aiomysql-tls.cnf
# use custom socket path
# we need to ensure that the socket path is writable for the user running the DB process in the container
sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }}
# mysql 5.7 container overrides the socket path in /etc/mysql/mysql.conf.d/mysqld.cnf
if [ "${{ join(matrix.db, '-') }}" = "mysql-5.7" ]
then
docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/mysql.conf.d/zz-aiomysql-socket.cnf
else
docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/conf.d/aiomysql-socket.cnf
fi
docker container start mysqld
# ensure server is started up
Expand All @@ -119,7 +134,7 @@ jobs:
run: |
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
timeout --preserve-status --signal=INT --verbose 5m \
pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306"
pytest --color=yes --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql ./tests --mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306"
env:
PYTHONUNBUFFERED: 1
DB: '${{ matrix.db[0] }}'
Expand Down
1 change: 1 addition & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ To be included in 1.0.0 (unreleased)
* Ensure connections are properly closed before raising an OperationalError when the server connection is lost #660
* Ensure connections are properly closed before raising an InternalError when packet sequence numbers are out of sync #660
* Unix sockets are now internally considered secure, allowing sha256_password and caching_sha2_password auth methods to be used #695
* Test suite now also tests unix socket connections #696


0.0.22 (2021-11-14)
Expand Down
67 changes: 46 additions & 21 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def pytest_generate_tests(metafunc):
mysql_addresses = []
ids = []

opt_mysql_unix_socket = \
list(metafunc.config.getoption("mysql_unix_socket"))
for i in range(len(opt_mysql_unix_socket)):
if "=" in opt_mysql_unix_socket[i]:
label, path = opt_mysql_unix_socket[i].split("=", 1)
mysql_addresses.append(path)
ids.append(label)
else:
mysql_addresses.append(opt_mysql_unix_socket[i])
ids.append("unix{}".format(i))

opt_mysql_address = list(metafunc.config.getoption("mysql_address"))
for i in range(len(opt_mysql_address)):
if "=" in opt_mysql_address[i]:
Expand Down Expand Up @@ -143,6 +154,12 @@ def pytest_addoption(parser):
default=[],
help="list of addresses to connect to: [name=]host[:port]",
)
parser.addoption(
"--mysql-unix-socket",
action="append",
default=[],
help="list of unix sockets to connect to: [name=]/path/to/socket",
)


@pytest.fixture
Expand Down Expand Up @@ -250,23 +267,30 @@ def ensure_mysql_version(request, mysql_image, mysql_tag):

@pytest.fixture(scope='session')
def mysql_server(mysql_image, mysql_tag, mysql_address):
ssl_directory = os.path.join(os.path.dirname(__file__),
'ssl_resources', 'ssl')
ca_file = os.path.join(ssl_directory, 'ca.pem')
unix_socket = type(mysql_address) is str

if not unix_socket:
ssl_directory = os.path.join(os.path.dirname(__file__),
'ssl_resources', 'ssl')
ca_file = os.path.join(ssl_directory, 'ca.pem')

ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
ctx.check_hostname = False
ctx.load_verify_locations(cafile=ca_file)
# ctx.verify_mode = ssl.CERT_NONE
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
ctx.check_hostname = False
ctx.load_verify_locations(cafile=ca_file)
# ctx.verify_mode = ssl.CERT_NONE

server_params = {
'host': mysql_address[0],
'port': mysql_address[1],
'user': 'root',
'password': os.environ.get("MYSQL_ROOT_PASSWORD"),
'ssl': ctx,
}

if unix_socket:
server_params["unix_socket"] = mysql_address
else:
server_params["host"] = mysql_address[0]
server_params["port"] = mysql_address[1]
server_params["ssl"] = ctx

try:
connection = pymysql.connect(
db='mysql',
Expand All @@ -275,21 +299,22 @@ def mysql_server(mysql_image, mysql_tag, mysql_address):
**server_params)

with connection.cursor() as cursor:
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")
if not unix_socket:
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")

result = cursor.fetchall()
result = {item['Variable_name']:
item['Value'] for item in result}
result = cursor.fetchall()
result = {item['Variable_name']:
item['Value'] for item in result}

assert result['have_ssl'] == "YES", \
"SSL Not Enabled on MySQL"
assert result['have_ssl'] == "YES", \
"SSL Not Enabled on MySQL"

cursor.execute("SHOW STATUS LIKE 'Ssl_version%'")
cursor.execute("SHOW STATUS LIKE 'Ssl_version%'")

result = cursor.fetchone()
# As we connected with TLS, it should start with that :D
assert result['Value'].startswith('TLS'), \
"Not connected to the database with TLS"
result = cursor.fetchone()
# As we connected with TLS, it should start with that :D
assert result['Value'].startswith('TLS'), \
"Not connected to the database with TLS"

# Drop possibly existing old databases
cursor.execute('DROP DATABASE IF EXISTS test_pymysql;')
Expand Down
File renamed without changes.
16 changes: 16 additions & 0 deletions tests/fixtures/my.cnf.unix.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#
# The MySQL database server configuration file.
#
[client]
user = {user}
socket = {unix_socket}
password = {password}
database = {db}
default-character-set = utf8

[client_with_unix_socket]
user = {user}
socket = {unix_socket}
password = {password}
database = {db}
default-character-set = utf8
11 changes: 9 additions & 2 deletions tests/sa/test_sa_compiled_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@
@pytest.fixture()
def make_engine(mysql_params, connection):
async def _make_engine(**kwargs):
if "unix_socket" in mysql_params:
conn_args = {"unix_socket": mysql_params["unix_socket"]}
else:
conn_args = {
"host": mysql_params['host'],
"port": mysql_params['port'],
}

return (await sa.create_engine(db=mysql_params['db'],
user=mysql_params['user'],
password=mysql_params['password'],
host=mysql_params['host'],
port=mysql_params['port'],
minsize=10,
**conn_args,
**kwargs))

return _make_engine
Expand Down
11 changes: 9 additions & 2 deletions tests/sa/test_sa_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,19 @@
@pytest.fixture()
def make_engine(mysql_params, connection):
async def _make_engine(**kwargs):
if "unix_socket" in mysql_params:
conn_args = {"unix_socket": mysql_params["unix_socket"]}
else:
conn_args = {
"host": mysql_params['host'],
"port": mysql_params['port'],
}

return (await sa.create_engine(db=mysql_params['db'],
user=mysql_params['user'],
password=mysql_params['password'],
host=mysql_params['host'],
port=mysql_params['port'],
minsize=10,
**conn_args,
**kwargs))

return _make_engine
Expand Down
11 changes: 9 additions & 2 deletions tests/sa/test_sa_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@
@pytest.fixture()
def make_engine(connection, mysql_params):
async def _make_engine(**kwargs):
if "unix_socket" in mysql_params:
conn_args = {"unix_socket": mysql_params["unix_socket"]}
else:
conn_args = {
"host": mysql_params['host'],
"port": mysql_params['port'],
}

return (await sa.create_engine(db=mysql_params['db'],
user=mysql_params['user'],
password=mysql_params['password'],
host=mysql_params['host'],
port=mysql_params['port'],
minsize=10,
**conn_args,
**kwargs))
return _make_engine

Expand Down
2 changes: 2 additions & 0 deletions tests/ssl_resources/socket.cnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mysqld]
socket = /socket-mount/mysql.sock
29 changes: 22 additions & 7 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
@pytest.fixture()
def fill_my_cnf(mysql_params):
tests_root = os.path.abspath(os.path.dirname(__file__))
path1 = os.path.join(tests_root, 'fixtures/my.cnf.tmpl')

if "unix_socket" in mysql_params:
tmpl_path = "fixtures/my.cnf.unix.tmpl"
else:
tmpl_path = "fixtures/my.cnf.tcp.tmpl"

path1 = os.path.join(tests_root, tmpl_path)
path2 = os.path.join(tests_root, 'fixtures/my.cnf')
with open(path1) as f1:
tmpl = f1.read()
Expand All @@ -31,8 +37,11 @@ async def test_config_file(fill_my_cnf, connection_creator, mysql_params):
path = os.path.join(tests_root, 'fixtures/my.cnf')
conn = await connection_creator(read_default_file=path)

assert conn.host == mysql_params['host']
assert conn.port == mysql_params['port']
if "unix_socket" in mysql_params:
assert conn.unix_socket == mysql_params["unix_socket"]
else:
assert conn.host == mysql_params['host']
assert conn.port == mysql_params['port']
assert conn.user, mysql_params['user']

# make sure connection is working
Expand Down Expand Up @@ -167,12 +176,15 @@ async def test_connection_gone_away(connection_creator):


@pytest.mark.run_loop
async def test_connection_info_methods(connection_creator):
async def test_connection_info_methods(connection_creator, mysql_params):
conn = await connection_creator()
# trhead id is int
assert isinstance(conn.thread_id(), int)
assert conn.character_set_name() in ('latin1', 'utf8mb4')
assert str(conn.port) in conn.get_host_info()
if "unix_socket" in mysql_params:
assert mysql_params["unix_socket"] in conn.get_host_info()
else:
assert str(conn.port) in conn.get_host_info()
assert isinstance(conn.get_server_info(), str)
# protocol id is int
assert isinstance(conn.get_proto_info(), int)
Expand Down Expand Up @@ -200,8 +212,11 @@ async def test_connection_ping(connection_creator):
@pytest.mark.run_loop
async def test_connection_properties(connection_creator, mysql_params):
conn = await connection_creator()
assert conn.host == mysql_params['host']
assert conn.port == mysql_params['port']
if "unix_socket" in mysql_params:
assert conn.unix_socket == mysql_params["unix_socket"]
else:
assert conn.host == mysql_params['host']
assert conn.port == mysql_params['port']
assert conn.user == mysql_params['user']
assert conn.db == mysql_params['db']
assert conn.echo is False
Expand Down
2 changes: 1 addition & 1 deletion tests/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ async def test_issue_17(connection, connection_creator, mysql_params):
async def test_issue_34(connection_creator):
try:
await connection_creator(host="localhost", port=1237,
user="root")
user="root", unix_socket=None)
pytest.fail()
except aiomysql.OperationalError as e:
assert 2003 == e.args[0]
Expand Down
7 changes: 7 additions & 0 deletions tests/test_sha_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ async def test_sha256_nopw(mysql_server, loop):
@pytest.mark.mysql_version('mysql', '8.0')
@pytest.mark.run_loop
async def test_sha256_pw(mysql_server, loop):
# https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html
# Unlike caching_sha2_password, the sha256_password plugin does not treat
# shared-memory connections as secure, even though share-memory transport
# is secure by default.
if "unix_socket" in mysql_server['conn_params']:
pytest.skip("sha256_password is not supported on unix sockets")

connection_data = copy.copy(mysql_server['conn_params'])
connection_data['user'] = 'user_sha256'
connection_data['password'] = 'pass_sha256'
Expand Down
10 changes: 8 additions & 2 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@


@pytest.mark.run_loop
async def test_tls_connect(mysql_server, loop):
async def test_tls_connect(mysql_server, loop, mysql_params):
if "unix_socket" in mysql_params:
pytest.skip("TLS is not supported on unix sockets")

async with create_pool(**mysql_server['conn_params'],
loop=loop) as pool:
async with pool.get() as conn:
Expand Down Expand Up @@ -32,7 +35,10 @@ async def test_tls_connect(mysql_server, loop):

# MySQL will get you to renegotiate if sent a cleartext password
@pytest.mark.run_loop
async def test_auth_plugin_renegotiation(mysql_server, loop):
async def test_auth_plugin_renegotiation(mysql_server, loop, mysql_params):
if "unix_socket" in mysql_params:
pytest.skip("TLS is not supported on unix sockets")

async with create_pool(**mysql_server['conn_params'],
auth_plugin='mysql_clear_password',
loop=loop) as pool:
Expand Down

0 comments on commit ee5c48e

Please sign in to comment.