From 3f4555d9f556a0b94f1366c8c6b6e18a2edaf8a0 Mon Sep 17 00:00:00 2001 From: Richard Schwab Date: Sun, 10 Jul 2022 03:34:51 +0200 Subject: [PATCH] Add `implicit_tls` connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL fixes #757 --- .github/workflows/ci-cd.yml | 40 +++++++++++-- CHANGES.txt | 2 + aiomysql/connection.py | 28 ++++++--- docs/connection.rst | 7 ++- tests/conftest.py | 63 ++++++++++++++++++--- tests/ssl_resources/haproxy.cfg | 34 +++++++++++ tests/ssl_resources/ssl/server-combined.pem | 46 +++++++++++++++ tests/test_sha_connection.py | 16 +++++- tests/test_ssl.py | 8 +++ 9 files changed, 221 insertions(+), 23 deletions(-) create mode 100644 tests/ssl_resources/haproxy.cfg create mode 100644 tests/ssl_resources/ssl/server-combined.pem diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 9f01a54d..c6eacedf 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -449,6 +449,11 @@ jobs: options: '--name=mysqld' env: MYSQL_ROOT_PASSWORD: rootpw + haproxy: + image: haproxytech/haproxy-alpine:2.6 + ports: + - 13306:13306 + options: '--name=haproxy' steps: - name: Setup Python ${{ matrix.py }} @@ -569,6 +574,14 @@ jobs: # unfortunately we need this hacky workaround as GitHub Actions service containers can't reference data from our repo. - name: Prepare mysql run: | + # inject HAproxy configuration + docker container stop haproxy + + docker container cp "${{ github.workspace }}/tests/ssl_resources/haproxy.cfg" haproxy:/usr/local/etc/haproxy/haproxy.cfg + docker container cp "${{ github.workspace }}/tests/ssl_resources/ssl/server-combined.pem" haproxy:/usr/local/etc/haproxy/haproxy.pem + + docker container start haproxy + # ensure server is started up while : do @@ -598,11 +611,30 @@ jobs: mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "SET GLOBAL local_infile=on" + # ensure we can login as root from any ip, as haproxy will not connect from localhost. + # at least on MySQL containers our account will be restricted to localhost by default. + mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "DROP USER 'root'@'%'" + mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "RENAME USER 'root'@'localhost' TO 'root'@'%'" + mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "FLUSH PRIVILEGES" + - name: Run tests - run: | - # timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs - timeout --preserve-status --signal=INT --verbose 570s \ - pytest --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql --cov tests ./tests --mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306" + # timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs + run: >- + timeout + --preserve-status + --signal=INT + --verbose 570s + pytest + --capture=no + --verbosity 2 + --cov-report term + --cov-report xml + --cov aiomysql + --cov tests + ./tests + --mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" + --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306" + --mysql-address-tls "tls-${{ join(matrix.db, '') }}=127.0.0.1:13306" env: PYTHONUNBUFFERED: 1 timeout-minutes: 10 diff --git a/CHANGES.txt b/CHANGES.txt index fd1746bb..31a5cee9 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -8,6 +8,8 @@ next (unreleased) * Remove deprecated Pool.get #706 +* Add `implicit_tls` connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL #757 + 0.1.1 (2022-05-08) ^^^^^^^^^^^^^^^^^^ diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 022315a9..b4714a8a 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -53,7 +53,7 @@ def connect(host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name='', server_public_key=None): + program_name='', server_public_key=None, implicit_tls=False): """See connections.Connection.__init__() for information about defaults.""" coro = _connect(host=host, user=user, password=password, db=db, @@ -66,7 +66,8 @@ def connect(host="localhost", user=None, password="", read_default_group=read_default_group, autocommit=autocommit, echo=echo, local_infile=local_infile, loop=loop, ssl=ssl, - auth_plugin=auth_plugin, program_name=program_name) + auth_plugin=auth_plugin, program_name=program_name, + implicit_tls=implicit_tls) return _ConnectionContextManager(coro) @@ -142,7 +143,7 @@ def __init__(self, host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name='', server_public_key=None): + program_name='', server_public_key=None, implicit_tls=False): """ Establish a connection to the MySQL database. Accepts several arguments: @@ -184,6 +185,9 @@ def __init__(self, host="localhost", user=None, password="", handshaking with MySQL. (omitted by default) :param server_public_key: SHA256 authentication plugin public key value. + :param implicit_tls: Establish TLS immediately, skipping non-TLS + preamble before upgrading to TLS. + (default: False) :param loop: asyncio loop """ self._loop = loop or asyncio.get_event_loop() @@ -218,6 +222,7 @@ def __init__(self, host="localhost", user=None, password="", self._auth_plugin_used = "" self._secure = False self.server_public_key = server_public_key + self._implicit_tls = implicit_tls self.salt = None from . import __version__ @@ -241,7 +246,7 @@ def __init__(self, host="localhost", user=None, password="", self.use_unicode = use_unicode self._ssl_context = ssl - if ssl: + if ssl and not implicit_tls: client_flag |= CLIENT.SSL self._encoding = charset_by_name(self._charset).encoding @@ -536,7 +541,8 @@ async def _connect(self): self._next_seq_id = 0 - await self._get_server_information() + if not self._implicit_tls: + await self._get_server_information() await self._request_authentication() self.connected_time = self._loop.time() @@ -727,7 +733,8 @@ async def _execute_command(self, command, sql): async def _request_authentication(self): # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse - if int(self.server_version.split('.', 1)[0]) >= 5: + # FIXME: change this before merge + if self._implicit_tls or int(self.server_version.split('.', 1)[0]) >= 5: self.client_flag |= CLIENT.MULTI_RESULTS if self.user is None: @@ -737,8 +744,10 @@ async def _request_authentication(self): data_init = struct.pack('` that connects to MySQL. @@ -93,6 +93,11 @@ Example:: ``sys.argv[0]`` is no longer passed by default :param server_public_key: SHA256 authenticaiton plugin public key value. :param loop: asyncio event loop instance or ``None`` for default one. + :param implicit_tls: Establish TLS immediately, skipping non-TLS + preamble before upgrading to TLS. + (default: False) + + .. versionadded:: 0.2 :returns: :class:`Connection` instance. diff --git a/tests/conftest.py b/tests/conftest.py index 5420b456..cb2f7136 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import gc import os import re +import socket import ssl import sys @@ -63,13 +64,26 @@ def pytest_generate_tests(metafunc): if ":" in addr: addr = addr.split(":", 1) - mysql_addresses.append((addr[0], int(addr[1]))) + mysql_addresses.append((addr[0], int(addr[1]), False)) else: - mysql_addresses.append((addr, 3306)) + mysql_addresses.append((addr, 3306, False)) + + opt_mysql_address_tls =\ + list(metafunc.config.getoption("mysql_address_tls")) + for i in range(len(opt_mysql_address_tls)): + if "=" in opt_mysql_address_tls[i]: + label, addr = opt_mysql_address_tls[i].split("=", 1) + ids.append(label) + else: + addr = opt_mysql_address_tls[i] + ids.append("tls{}".format(i)) + + addr = addr.split(":", 1) + mysql_addresses.append((addr[0], int(addr[1]), True)) # default to connecting to localhost if len(mysql_addresses) == 0: - mysql_addresses = [("127.0.0.1", 3306)] + mysql_addresses = [("127.0.0.1", 3306, False)] ids = ["tcp-local"] assert len(mysql_addresses) == len(set(mysql_addresses)), \ @@ -153,6 +167,12 @@ def pytest_addoption(parser): default=[], help="list of addresses to connect to: [name=]host[:port]", ) + parser.addoption( + "--mysql-address-tls", + action="append", + default=[], + help="list of addresses to connect to using implicit TLS: [name=]host:port", + ) parser.addoption( "--mysql-unix-socket", action="append", @@ -249,6 +269,7 @@ def _register_table(table_name): @pytest.fixture(scope='session') def mysql_server(mysql_address): unix_socket = type(mysql_address) is str + implicit_tls = not unix_socket and mysql_address[2] if not unix_socket: ssl_directory = os.path.join(os.path.dirname(__file__), @@ -270,14 +291,34 @@ def mysql_server(mysql_address): else: server_params["host"] = mysql_address[0] server_params["port"] = mysql_address[1] + + if not implicit_tls: server_params["ssl"] = ctx try: - connection = pymysql.connect( - db='mysql', - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor, - **server_params) + if implicit_tls: + sock = ctx.wrap_socket( + socket.create_connection( + (server_params["host"], server_params["port"]), + ), + server_hostname=server_params["host"], + ) + connection = pymysql.Connection( + db='mysql', + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor, + **server_params, + defer_connect=True, + ) + connection.connect(sock) + + else: + connection = pymysql.connect( + db='mysql', + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor, + **server_params, + ) with connection.cursor() as cursor: cursor.execute("SELECT VERSION() AS version") @@ -297,7 +338,7 @@ def mysql_server(mysql_address): pytest.fail("Unable to determine database type from {!r}" .format(server_version_tuple)) - if not unix_socket: + if not unix_socket and not implicit_tls: cursor.execute("SHOW VARIABLES LIKE '%ssl%';") result = cursor.fetchall() @@ -352,6 +393,10 @@ def mysql_server(mysql_address): except Exception: pytest.fail("Cannot initialize MySQL environment") + if implicit_tls: + server_params["ssl"] = ctx + server_params["implicit_tls"] = implicit_tls + return { "conn_params": server_params, "server_version": server_version, diff --git a/tests/ssl_resources/haproxy.cfg b/tests/ssl_resources/haproxy.cfg new file mode 100644 index 00000000..ecd02953 --- /dev/null +++ b/tests/ssl_resources/haproxy.cfg @@ -0,0 +1,34 @@ +# partially based on https://github.com/haproxytech/haproxy-docker-alpine/blob/8df4bf3078a338759ae484ab26908f7a4ba9484e/2.6/haproxy.cfg + +global + log stdout format raw local0 + + chroot /var/lib/haproxy + pidfile /var/run/haproxy.pid + maxconn 4000 + user haproxy + group haproxy + + ssl-default-bind-ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384 + ssl-default-bind-ciphersuites TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256 + ssl-default-bind-options prefer-client-ciphers no-sslv3 no-tlsv10 no-tlsv11 no-tls-tickets + ssl-default-server-ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384 + ssl-default-server-ciphersuites TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256 + ssl-default-server-options no-sslv3 no-tlsv10 no-tlsv11 no-tls-tickets + + +defaults + mode tcp + log global + timeout connect 10s + timeout client 1m + timeout server 1m + + +frontend tcp-13306-front + bind :13306 ssl crt /usr/local/etc/haproxy/haproxy.pem + default_backend tcp-13306-backend + + +backend tcp-13306-backend + server mysql mysqld:3306 diff --git a/tests/ssl_resources/ssl/server-combined.pem b/tests/ssl_resources/ssl/server-combined.pem new file mode 100644 index 00000000..bc496bcc --- /dev/null +++ b/tests/ssl_resources/ssl/server-combined.pem @@ -0,0 +1,46 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAn1LwhlVDVK3k0W3LgUENsPflTN0zzDa94HWVdV9coN07RCXP +5AWJ0j34c90tEWq/mPKXcQi765RiXgWnr2l8M4/LVvOUv95r+QRbRu9ceWwf3bnh +uMIBGhBIblySEB7PPZ8xoIIAm4WQrg6dpasKbV9lPz6/WCQ20rpKN6adBLzIu6n5 +zcs62WPYC6lJE3dUB/mzVkFuRgzCvX2w17xpnv8cg+DiyxB8uwnPishgc0jhaizV +aZqaQuDOY3/WkzNrOe3xGgxPraU+BPCynIl2CgFtN5/qX6BAn9SZ3rJRpbZoW7wD +Ri7gOJXzq/yOr+X9kLDUvCElf3e/mEfjvZOaFQIDAQABAoIBAFNubFP8LEEYut1M +4Kez+EZ22hXRNEG5XN9A095d7LS0hUefgWkH2W9GUmfiJ6qaOvEOAG4Jw9aOoqBX +18LMu2SI5VOIRJnhEKubM21HBSb0jw9eOqy0sz0Bz9wzD63vZFkBl0xVJ5pJbEUp +lDZgBhrWPL/MzQiMFkVtllXkIw+KNRIokV0HJn0VNUm+ORaDO0TnTAiqL7Wv553+ +lWvGaeI4NpMZOPtlgqym1neQmllkeB07pSEtTopO0iINwuTuDUU7IMleN4eOomfh +GwirEHUban8BDL5djckS0GrkUq9EuJbvjEikHAFuNwz7D1Sn3LsDYygD6pKWoXxh +8Ng/AZUCgYEAy1ma5xKvRZZ6QUFssZjtwbruxhj0j3lPA3t9VGlp4SQb3g6FgeOV +6dyJLuuFfSLT9ps+k8D1Er+v85OqO0IM8TR12rPjWMzSjBz/xHb+0uRbEoPwArfn +wcPj3NN6M+tcZH//djogrpphN7u+BudBG6YWzUOQjLwdStM52s8hFD8CgYEAyJM2 +M5AWPYL57CI1lqzN135aS6OFyG4N2+rtEWEGAmoeP+NoSGFmQPLTWBdT4ZOGl59/ +fSBrWRKted6/H9frjZuSdsGXFMb7e71DDyYZq7tJbYEExc0a6BePINLCwLCIMKHj +PcoPGVsdQXfZK17+qzACwDNbEis/J3H7xcPv7KsCgYBO+SG7k/oV4HbiWPJJlsbf +ciXBMXfpMIeLJq5p1faUxV09RA59f1F9XXS5kCZrjtca8ve+kjWbbm56/mIiWWiF +VIZgxXQJzKIIYErEliIo7R6hdjQEGkAbdGROIqNW/pUHQt6Hn9OJe9M9vd/y9mTG +xB4e4ZqFzZjisl3JqJ+EKQKBgFyJboxDgb9HWj7TWZ32g9FT/hy/iM172PEJZe6K +sNcUVnhrVoVuSlrUrSULPivogEQb1hnIhz5FG7wKRGtQluByUhRwJF/1nbjtDK9E +iLtuYOYgjC8l/a/ujp46Hpf/2hV12v1655RvMQQvYwZbgWtBb0N1biLnyO9N6zbG +uz6ZAoGAXW3mhN9zbN6EgsIlLu+wCdPuC9Vs968gCT612E8Ijul1kiSODG3rbMoG +2FbjqZLahyX8vWVhX/m4xDqO7DXTwi81polfFuxbc/PimOLI4DKq62lWLMBDmUba +X8Bxal3FXLvdcNEplcXadqmxJeXMYnYsfC+MCQhe6im0bGyDOAg= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIC/jCCAeYCAQEwDQYJKoZIhvcNAQELBQAwRTELMAkGA1UEBhMCQVUxEzARBgNV +BAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0 +ZDAeFw0xODAzMjkyMTMzMTFaFw0yODAyMDUyMTMzMTFaMEUxCzAJBgNVBAYTAkFV +MRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRz +IFB0eSBMdGQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCfUvCGVUNU +reTRbcuBQQ2w9+VM3TPMNr3gdZV1X1yg3TtEJc/kBYnSPfhz3S0Rar+Y8pdxCLvr +lGJeBaevaXwzj8tW85S/3mv5BFtG71x5bB/dueG4wgEaEEhuXJIQHs89nzGgggCb +hZCuDp2lqwptX2U/Pr9YJDbSuko3pp0EvMi7qfnNyzrZY9gLqUkTd1QH+bNWQW5G +DMK9fbDXvGme/xyD4OLLEHy7Cc+KyGBzSOFqLNVpmppC4M5jf9aTM2s57fEaDE+t +pT4E8LKciXYKAW03n+pfoECf1JneslGltmhbvANGLuA4lfOr/I6v5f2QsNS8ISV/ +d7+YR+O9k5oVAgMBAAEwDQYJKoZIhvcNAQELBQADggEBACUWEAWfv3EOy8JmUbWA +jEytJIh/N55hfknninjOBPMv1U1BRixJwXVKlwA8o+8JiacriObWeje2rDUOt6zY +U5DnySQbTYJcJZ9jprqU7VXST7D9NvA0ueLclWTZcqIr/josyhK+l1YbezFYBf41 +JQ4PVzkNz9Of4e022qONnlEX0MbtFlcyPEK4yWyXLAhidPAV9QcOCy85vob0+3EE +hmRVVzcTv4Pbzgpee0ZORqozSLzZ3N6RvDyYIczqaytcbyvaQ7GuykE7XvIK5hz0 +EM8pwsvxSY1z2yNIw38M8ZOYk18LsEGkf/TyT6eQqymMMD9Qy8rOTsOLfY5eQCf7 +pKQ= +-----END CERTIFICATE----- diff --git a/tests/test_sha_connection.py b/tests/test_sha_connection.py index 47baa0a6..cf328a60 100644 --- a/tests/test_sha_connection.py +++ b/tests/test_sha_connection.py @@ -40,7 +40,7 @@ async def test_sha256_nopw(mysql_server, loop): async with create_pool(**connection_data, loop=loop) as pool: async with pool.acquire() as conn: - # User doesnt have any permissions to look at DBs + # User doesn't have any permissions to look at DBs # But as 8.0 will default to caching_sha2_password assert conn._auth_plugin_used == 'sha256_password' @@ -56,6 +56,13 @@ async def test_sha256_pw(mysql_server, loop): if "unix_socket" in mysql_server['conn_params']: pytest.skip("sha256_password is not supported on unix sockets") + # In our tests, our database server considers implicit TLS connections as + # non-TLS connections, as TLS is terminated before reaching it. + if mysql_server['conn_params'].get('implicit_tls', False): + pytest.skip( + "sha256_password is not supported on implicit TLS connections" + ) + connection_data = copy.copy(mysql_server['conn_params']) connection_data['user'] = 'user_sha256' connection_data['password'] = 'pass_sha256' @@ -88,6 +95,13 @@ async def test_cached_sha256_nopw(mysql_server, loop): async def test_cached_sha256_pw(mysql_server, loop): ensure_mysql_version(mysql_server) + # In our tests, our database server considers implicit TLS connections as + # non-TLS connections, as TLS is terminated before reaching it. + if mysql_server['conn_params'].get('implicit_tls', False): + pytest.skip( + "sha256_password is not supported on implicit TLS connections" + ) + connection_data = copy.copy(mysql_server['conn_params']) connection_data['user'] = 'user_caching_sha2' connection_data['password'] = 'pass_caching_sha2' diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 9cc45fce..3f3e2c4d 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -7,6 +7,10 @@ async def test_tls_connect(mysql_server, loop, mysql_params): if "unix_socket" in mysql_params: pytest.skip("TLS is not supported on unix sockets") + if mysql_params.get("implicit_tls", False): + pytest.skip( + "Explicit TLS is not supported on implicit TLS connections", + ) async with create_pool(**mysql_server['conn_params'], loop=loop) as pool: @@ -38,6 +42,10 @@ async def test_tls_connect(mysql_server, loop, mysql_params): async def test_auth_plugin_renegotiation(mysql_server, loop, mysql_params): if "unix_socket" in mysql_params: pytest.skip("TLS is not supported on unix sockets") + if mysql_params.get("implicit_tls", False): + pytest.skip( + "Explicit TLS is not supported on implicit TLS connections", + ) async with create_pool(**mysql_server['conn_params'], auth_plugin='mysql_clear_password',