Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSL Support (plus mysql_clear_password plugin for RDS) #280

Merged
merged 3 commits into from
Apr 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
92 changes: 82 additions & 10 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def connect(host="localhost", user=None, password="",
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, read_default_group=None,
no_delay=None, autocommit=False, echo=False,
local_infile=False, loop=None):
local_infile=False, loop=None, ssl=None, auth_plugin=''):
"""See connections.Connection.__init__() for information about
defaults."""
coro = _connect(host=host, user=user, password=password, db=db,
Expand All @@ -68,7 +68,8 @@ def connect(host="localhost", user=None, password="",
connect_timeout=connect_timeout,
read_default_group=read_default_group,
no_delay=no_delay, autocommit=autocommit, echo=echo,
local_infile=local_infile, loop=loop)
local_infile=local_infile, loop=loop, ssl=ssl,
auth_plugin=auth_plugin)
return _ConnectionContextManager(coro)


Expand All @@ -93,7 +94,7 @@ def __init__(self, host="localhost", user=None, password="",
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, read_default_group=None,
no_delay=None, autocommit=False, echo=False,
local_infile=False, loop=None):
local_infile=False, loop=None, ssl=None, auth_plugin=''):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
Expand Down Expand Up @@ -164,6 +165,9 @@ def __init__(self, host="localhost", user=None, password="",
self._no_delay = no_delay
self._echo = echo
self._last_usage = self._loop.time()
self._client_auth_plugin = auth_plugin
self._server_auth_plugin = ""
self._auth_plugin_used = ""

self._unix_socket = unix_socket
if charset:
Expand All @@ -176,6 +180,10 @@ def __init__(self, host="localhost", user=None, password="",
if use_unicode is not None:
self.use_unicode = use_unicode

self._ssl_context = ssl
if ssl:
client_flag |= CLIENT.SSL

self._encoding = charset_by_name(self._charset).encoding

if local_infile:
Expand Down Expand Up @@ -209,8 +217,6 @@ def __init__(self, host="localhost", user=None, password="",
# user
self._close_reason = None

self._auth_plugin_name = ""

@property
def host(self):
"""MySQL server IP address or name"""
Expand Down Expand Up @@ -663,6 +669,31 @@ def _request_authentication(self):
if self.user is None:
raise ValueError("Did not specify a username")

if self._ssl_context:
# capablities, max packet, charset
data = struct.pack('<IIB', self.client_flag, 16777216, 33)
data += b'\x00' * (32 - len(data))

self.write_packet(data)

# Stop sending events to data_received
self._writer.transport.pause_reading()

# Get the raw socket from the transport
raw_sock = self._writer.transport.get_extra_info('socket',
default=None)
if raw_sock is None:
raise RuntimeError("Transport does not expose socket instance")

# MySQL expects TLS negotiation to happen in the middle of a
# TCP connection not at start. Passing in a socket to
# open_connection will cause it to negotiate TLS on an existing
# connection not initiate a new one.
self._reader, self._writer = yield from asyncio.open_connection(
sock=raw_sock, ssl=self._ssl_context, loop=self._loop,
server_hostname=self._host
)

charset_id = charset_by_name(self.charset).id
if isinstance(self.user, str):
_user = self.user.encode(self.encoding)
Expand All @@ -673,8 +704,16 @@ def _request_authentication(self):
data = data_init + _user + b'\0'

authresp = b''
if self._auth_plugin_name in ('', 'mysql_native_password'):

auth_plugin = self._client_auth_plugin
if not self._client_auth_plugin:
# Contains the auth plugin from handshake
auth_plugin = self._server_auth_plugin

if auth_plugin in ('', 'mysql_native_password'):
authresp = _scramble(self._password.encode('latin1'), self.salt)
elif auth_plugin in ('', 'mysql_clear_password'):
authresp = self._password.encode('latin1') + b'\0'

if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
data += lenenc_int(len(authresp)) + authresp
Expand All @@ -693,11 +732,13 @@ def _request_authentication(self):
data += db + b'\0'

if self.server_capabilities & CLIENT.PLUGIN_AUTH:
name = self._auth_plugin_name
name = auth_plugin
if isinstance(name, str):
name = name.encode('ascii')
data += name + b'\0'

self._auth_plugin_used = auth_plugin

self.write_packet(data)
auth_packet = yield from self._read_packet()

Expand All @@ -710,14 +751,45 @@ def _request_authentication(self):
plugin_name = auth_packet.read_string()
if (self.server_capabilities & CLIENT.PLUGIN_AUTH and
plugin_name is not None):
auth_packet = self._process_auth(plugin_name, auth_packet)
auth_packet = yield from self._process_auth(
plugin_name, auth_packet)
else:
# send legacy handshake
data = _scramble_323(self._password.encode('latin1'),
self.salt) + b'\0'
self.write_packet(data)
auth_packet = yield from self._read_packet()

@asyncio.coroutine
def _process_auth(self, plugin_name, auth_packet):
if plugin_name == b"mysql_native_password":
# https://dev.mysql.com/doc/internals/en/
# secure-password-authentication.html#packet-Authentication::
# Native41
data = _scramble(self._password.encode('latin1'),
auth_packet.read_all())
elif plugin_name == b"mysql_old_password":
# https://dev.mysql.com/doc/internals/en/
# old-password-authentication.html
data = _scramble_323(self._password.encode('latin1'),
auth_packet.read_all()) + b'\0'
elif plugin_name == b"mysql_clear_password":
# https://dev.mysql.com/doc/internals/en/
# clear-text-authentication.html
data = self._password.encode('latin1') + b'\0'
else:
raise OperationalError(
2059, "Authentication plugin '%s' not configured" % plugin_name
)

self.write_packet(data)
pkt = yield from self._read_packet()
pkt.check_error()

self._auth_plugin_used = plugin_name

return pkt

# _mysql support
def thread_id(self):
return self.server_thread_id[0]
Expand Down Expand Up @@ -786,9 +858,9 @@ def _get_server_information(self):
server_end = data.find(b'\0', i)
if server_end < 0: # pragma: no cover - very specific upstream bug
# not found \0 and last field so take it all
self._auth_plugin_name = data[i:].decode('latin1')
self._server_auth_plugin = data[i:].decode('latin1')
else:
self._auth_plugin_name = data[i:server_end].decode('latin1')
self._server_auth_plugin = data[i:server_end].decode('latin1')

def get_transaction_status(self):
return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_IN_TRANS)
Expand Down
52 changes: 52 additions & 0 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from aiomysql import create_pool

import pytest


@pytest.mark.run_loop
async def test_tls_connect(mysql_server, loop):
async with create_pool(**mysql_server['conn_params'],
loop=loop) as pool:
async with pool.get() as conn:
async with conn.cursor() as cur:
# Run simple command
await cur.execute("SHOW DATABASES;")
value = await cur.fetchall()

values = [item[0] for item in value]
# Spot check the answers, we should at least have mysql
# and information_schema
assert 'mysql' in values, \
'Could not find the "mysql" table'
assert 'information_schema' in values, \
'Could not find the "mysql" table'

# Check TLS variables
await cur.execute("SHOW STATUS LIKE '%Ssl_version%';")
value = await cur.fetchone()

# The context has TLS
assert value[1].startswith('TLS'), \
'Not connected to the database with TLS'


# MySQL will get you to renegotiate if sent a cleartext password
@pytest.mark.run_loop
async def test_auth_plugin_renegotiation(mysql_server, loop):
async with create_pool(**mysql_server['conn_params'],
auth_plugin='mysql_clear_password',
loop=loop) as pool:
async with pool.get() as conn:
async with conn.cursor() as cur:
# Run simple command
await cur.execute("SHOW DATABASES;")
value = await cur.fetchall()

assert len(value), 'No databases found'

assert conn._client_auth_plugin == 'mysql_clear_password', \
'Client did not try clear password auth'
assert conn._server_auth_plugin == 'mysql_native_password', \
'Server did not ask for native auth'
assert conn._auth_plugin_used == b'mysql_native_password', \
'Client did not renegotiate with native auth'