From 612d060af84eb32c8d4c1adb2f3fdc34e51c4757 Mon Sep 17 00:00:00 2001 From: Tomas Smetana Date: Mon, 25 Apr 2022 11:12:20 +0200 Subject: [PATCH] Fix race in stunnel port selection --- src/mount_efs/__init__.py | 171 +++++++++++--------- test/mount_efs_test/test_bootstrap_tls.py | 3 + test/mount_efs_test/test_choose_tls_port.py | 45 +++++- 3 files changed, 132 insertions(+), 87 deletions(-) diff --git a/src/mount_efs/__init__.py b/src/mount_efs/__init__.py index d345b43..7200d2c 100755 --- a/src/mount_efs/__init__.py +++ b/src/mount_efs/__init__.py @@ -939,7 +939,7 @@ def get_tls_port_range(config): return lower_bound, upper_bound -def choose_tls_port(config, options): +def choose_tls_port_and_bind_sock(state_file_dir, fs_id, mountpoint, config, options): if "tlsport" in options: ports_to_try = [int(options["tlsport"])] else: @@ -954,13 +954,14 @@ def choose_tls_port(config, options): assert len(tls_ports) == len(ports_to_try) if "netns" not in options: - tls_port = find_tls_port_in_range(ports_to_try) + tls_port_sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try) else: with NetNS(nspath=options["netns"]): - tls_port = find_tls_port_in_range(ports_to_try) + tls_port_sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try) - if tls_port: - return tls_port + if tls_port_sock: + tls_port = tls_port_sock.getsockname()[1] + return tls_port_sock, tls_port if "tlsport" in options: fatal_error( @@ -974,14 +975,18 @@ def choose_tls_port(config, options): ) -def find_tls_port_in_range(ports_to_try): +def find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try): sock = socket.socket() for tls_port in ports_to_try: + mount_filename = get_mount_specific_filename(fs_id, mountpoint, tls_port) + config_file = get_stunnel_config_filename(state_file_dir, mount_filename) + if os.access(config_file, os.R_OK): + logging.info("confifguration for port %s already exists, trying another port", tls_port) + continue try: logging.info("binding %s", tls_port) sock.bind(("localhost", tls_port)) - sock.close() - return tls_port + return sock except socket.error as e: logging.info(e) continue @@ -1262,9 +1267,7 @@ def write_stunnel_config_file( ) logging.debug("Writing stunnel configuration:\n%s", stunnel_config) - stunnel_config_file = os.path.join( - state_file_dir, "stunnel-config.%s" % mount_filename - ) + stunnel_config_file = get_stunnel_config_filename(state_file_dir, mount_filename) with open(stunnel_config_file, "w") as f: f.write(stunnel_config) @@ -1464,6 +1467,10 @@ def create_required_directory(config, directory): raise +def get_stunnel_config_filename(state_file_dir, mount_filename): + return os.path.join(state_file_dir, "stunnel-config.%s" % mount_filename) + + @contextmanager def bootstrap_tls( config, @@ -1475,82 +1482,88 @@ def bootstrap_tls( state_file_dir=STATE_FILE_DIR, fallback_ip_address=None, ): - tls_port = choose_tls_port(config, options) - # override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel. - # if the user has specified tlsport=X at the command line this will just re-set tlsport to X. - options["tlsport"] = tls_port - - use_iam = "iam" in options - ap_id = options.get("accesspoint") - cert_details = {} - security_credentials = None - client_info = get_client_info(config) - region = get_target_region(config) - - if use_iam: - aws_creds_uri = options.get("awscredsuri") - if aws_creds_uri: - kwargs = {"aws_creds_uri": aws_creds_uri} - else: - kwargs = {"awsprofile": get_aws_profile(options, use_iam)} + sock, tls_port = choose_tls_port_and_bind_sock(state_file_dir, fs_id, mountpoint, config, options) + try: + # override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel. + # if the user has specified tlsport=X at the command line this will just re-set tlsport to X. + options["tlsport"] = tls_port + + use_iam = "iam" in options + ap_id = options.get("accesspoint") + cert_details = {} + security_credentials = None + client_info = get_client_info(config) + region = get_target_region(config) + + if use_iam: + aws_creds_uri = options.get("awscredsuri") + if aws_creds_uri: + kwargs = {"aws_creds_uri": aws_creds_uri} + else: + kwargs = {"awsprofile": get_aws_profile(options, use_iam)} - security_credentials, credentials_source = get_aws_security_credentials( - config, use_iam, region, **kwargs - ) + security_credentials, credentials_source = get_aws_security_credentials( + config, use_iam, region, **kwargs + ) - if credentials_source: - cert_details["awsCredentialsMethod"] = credentials_source + if credentials_source: + cert_details["awsCredentialsMethod"] = credentials_source - if ap_id: - cert_details["accessPoint"] = ap_id + if ap_id: + cert_details["accessPoint"] = ap_id - # additional symbol appended to avoid naming collisions - cert_details["mountStateDir"] = ( - get_mount_specific_filename(fs_id, mountpoint, tls_port) + "+" - ) - # common name for certificate signing request is max 64 characters - cert_details["commonName"] = socket.gethostname()[0:64] - region = get_target_region(config) - cert_details["region"] = region - cert_details["certificateCreationTime"] = create_certificate( - config, - cert_details["mountStateDir"], - cert_details["commonName"], - cert_details["region"], - fs_id, - security_credentials, - ap_id, - client_info, - base_path=state_file_dir, - ) - cert_details["certificate"] = os.path.join( - state_file_dir, cert_details["mountStateDir"], "certificate.pem" - ) - cert_details["privateKey"] = get_private_key_path() - cert_details["fsId"] = fs_id + # additional symbol appended to avoid naming collisions + cert_details["mountStateDir"] = ( + get_mount_specific_filename(fs_id, mountpoint, tls_port) + "+" + ) + # common name for certificate signing request is max 64 characters + cert_details["commonName"] = socket.gethostname()[0:64] + region = get_target_region(config) + cert_details["region"] = region + cert_details["certificateCreationTime"] = create_certificate( + config, + cert_details["mountStateDir"], + cert_details["commonName"], + cert_details["region"], + fs_id, + security_credentials, + ap_id, + client_info, + base_path=state_file_dir, + ) + cert_details["certificate"] = os.path.join( + state_file_dir, cert_details["mountStateDir"], "certificate.pem" + ) + cert_details["privateKey"] = get_private_key_path() + cert_details["fsId"] = fs_id - start_watchdog(init_system) + start_watchdog(init_system) - if not os.path.exists(state_file_dir): - create_required_directory(config, state_file_dir) + if not os.path.exists(state_file_dir): + create_required_directory(config, state_file_dir) - verify_level = int(options.get("verify", DEFAULT_STUNNEL_VERIFY_LEVEL)) - ocsp_enabled = is_ocsp_enabled(config, options) + verify_level = int(options.get("verify", DEFAULT_STUNNEL_VERIFY_LEVEL)) + ocsp_enabled = is_ocsp_enabled(config, options) - stunnel_config_file = write_stunnel_config_file( - config, - state_file_dir, - fs_id, - mountpoint, - tls_port, - dns_name, - verify_level, - ocsp_enabled, - options, - region, - cert_details=cert_details, - fallback_ip_address=fallback_ip_address, - ) + stunnel_config_file = write_stunnel_config_file( + config, + state_file_dir, + fs_id, + mountpoint, + tls_port, + dns_name, + verify_level, + ocsp_enabled, + options, + region, + cert_details=cert_details, + fallback_ip_address=fallback_ip_address, + ) + except Exception as e: + logging.error("Error while creating the configuration file: %s" % e) + finally: + # close the socket now, so the stunnel process can bind to the port + sock.close() tunnel_args = [_stunnel_bin(), stunnel_config_file] if "netns" in options: tunnel_args = ["nsenter", "--net=" + options["netns"]] + tunnel_args diff --git a/test/mount_efs_test/test_bootstrap_tls.py b/test/mount_efs_test/test_bootstrap_tls.py index d044173..a045a42 100644 --- a/test/mount_efs_test/test_bootstrap_tls.py +++ b/test/mount_efs_test/test_bootstrap_tls.py @@ -170,6 +170,9 @@ def test_bootstrap_tls_non_default_port(mocker, tmpdir): popen_mock, write_config_mock = setup_mocks(mocker) mocker.patch("os.rename") state_file_dir = str(tmpdir) + fake_sock = MagicMock() + fake_sock.getsockname.return_value = ("localhost", 1000) + mocker.patch("socket.socket", return_value=fake_sock) tls_port = 1000 mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") diff --git a/test/mount_efs_test/test_choose_tls_port.py b/test/mount_efs_test/test_choose_tls_port.py index 800964e..4f5194c 100644 --- a/test/mount_efs_test/test_choose_tls_port.py +++ b/test/mount_efs_test/test_choose_tls_port.py @@ -5,6 +5,7 @@ # the License. import socket +import random from unittest.mock import MagicMock import pytest @@ -20,6 +21,9 @@ DEFAULT_TLS_PORT_RANGE_LOW = 20049 DEFAULT_TLS_PORT_RANGE_HIGH = 20449 +FS_ID = "fs-deadbeef" +MOUNT_POINT = "/mnt" +STATE_FILE_DIR = "/tmp" def _get_config(): @@ -42,10 +46,13 @@ def _get_config(): def test_choose_tls_port_first_try(mocker): - mocker.patch("socket.socket", return_value=MagicMock()) + fake_sock = MagicMock() + tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH) + fake_sock.getsockname.return_value = ("localhost", tls_port) + mocker.patch("socket.socket", return_value=fake_sock) options = {} - tls_port = mount_efs.choose_tls_port(_get_config(), options) + sock, tls_port = mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH @@ -53,11 +60,13 @@ def test_choose_tls_port_first_try(mocker): def test_choose_tls_port_second_try(mocker): bad_sock = MagicMock() bad_sock.bind.side_effect = [socket.error, None] + tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH) + bad_sock.getsockname.return_value = ("localhost", tls_port) options = {} mocker.patch("socket.socket", return_value=bad_sock) - tls_port = mount_efs.choose_tls_port(_get_config(), options) + sock, tls_port = mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH assert 2 == bad_sock.bind.call_count @@ -71,7 +80,7 @@ def test_choose_tls_port_never_succeeds(mocker, capsys): mocker.patch("socket.socket", return_value=bad_sock) with pytest.raises(SystemExit) as ex: - mount_efs.choose_tls_port(_get_config(), options) + mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) assert 0 != ex.value.code @@ -85,10 +94,12 @@ def test_choose_tls_port_never_succeeds(mocker, capsys): def test_choose_tls_port_option_specified(mocker): - mocker.patch("socket.socket", return_value=MagicMock()) + fake_sock = MagicMock() + fake_sock.getsockname.return_value = ("localhost", 1000) + mocker.patch("socket.socket", return_value=fake_sock) options = {"tlsport": 1000} - tls_port = mount_efs.choose_tls_port(_get_config(), options) + sock, tls_port = mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) assert 1000 == tls_port @@ -101,7 +112,7 @@ def test_choose_tls_port_option_specified_unavailable(mocker, capsys): mocker.patch("socket.socket", return_value=bad_sock) with pytest.raises(SystemExit) as ex: - mount_efs.choose_tls_port(_get_config(), options) + mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) assert 0 != ex.value.code @@ -117,7 +128,7 @@ def test_choose_tls_port_under_netns(mocker, capsys): mocker.patch("socket.socket", return_value=MagicMock()) options = {"netns": "/proc/1000/ns/net"} - mount_efs.choose_tls_port(_get_config(), options) + mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) utils.assert_called(setns_mock) @@ -130,3 +141,21 @@ def test_verify_tls_port(mocker): result = mount_efs.verify_tlsport_can_be_connected(1000) assert result is True assert 2 == sock.connect.call_count + +def test_choose_tls_port_already_configured(mocker, capsys): + fake_sock = MagicMock() + tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH) + fake_sock.getsockname.return_value = ("localhost", tls_port) + mocker.patch("socket.socket", return_value=fake_sock) + access_mock = mocker.patch("os.access", return_value=True) + options = {} + + with pytest.raises(SystemExit) as ex: + mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) + + assert 0 != ex.value.code + + out, err = capsys.readouterr() + assert "Failed to locate an available port" in err + + utils.assert_called_n_times(access_mock, DEFAULT_TLS_PORT_RANGE_HIGH - DEFAULT_TLS_PORT_RANGE_LOW)