diff --git a/src/mount_efs/__init__.py b/src/mount_efs/__init__.py index 29fe7acd..eb6a0bc3 100755 --- a/src/mount_efs/__init__.py +++ b/src/mount_efs/__init__.py @@ -923,7 +923,7 @@ def get_tls_port_range(config): return lower_bound, upper_bound -def choose_tls_port(config, options): +def choose_tls_port(state_file_dir, fs_id, mountpoint, config, options): if "tlsport" in options: ports_to_try = [int(options["tlsport"])] else: @@ -938,13 +938,13 @@ def choose_tls_port(config, options): assert len(tls_ports) == len(ports_to_try) if "netns" not in options: - tls_port = find_tls_port_in_range(ports_to_try) + sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try) else: with NetNS(nspath=options["netns"]): - tls_port = find_tls_port_in_range(ports_to_try) + sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try) - if tls_port: - return tls_port + if sock: + return sock if "tlsport" in options: fatal_error( @@ -958,14 +958,18 @@ def choose_tls_port(config, options): ) -def find_tls_port_in_range(ports_to_try): +def find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try): sock = socket.socket() for tls_port in ports_to_try: + mount_filename = get_mount_specific_filename(fs_id, mountpoint, tls_port) + config_file = get_stunnel_config_filename(state_file_dir, mount_filename) + if os.access(config_file, os.R_OK): + logging.info("confifguration for port %s already exists, trying another port", tls_port) + continue try: logging.info("binding %s", tls_port) sock.bind(("localhost", tls_port)) - sock.close() - return tls_port + return sock except socket.error as e: logging.info(e) continue @@ -1210,9 +1214,7 @@ def write_stunnel_config_file( ) logging.debug("Writing stunnel configuration:\n%s", stunnel_config) - stunnel_config_file = os.path.join( - state_file_dir, "stunnel-config.%s" % mount_filename - ) + stunnel_config_file = get_stunnel_config_filename(state_file_dir, mount_filename) with open(stunnel_config_file, "w") as f: f.write(stunnel_config) @@ -1410,6 +1412,10 @@ def create_required_directory(config, directory): raise +def get_stunnel_config_filename(state_file_dir, mount_filename): + return os.path.join(state_file_dir, "stunnel-config.%s" % mount_filename) + + @contextmanager def bootstrap_tls( config, @@ -1421,7 +1427,8 @@ def bootstrap_tls( state_file_dir=STATE_FILE_DIR, fallback_ip_address=None, ): - tls_port = choose_tls_port(config, options) + sock = choose_tls_port(state_file_dir, fs_id, mountpoint, config, options) + tls_port = sock.getsockname()[1] # override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel. # if the user has specified tlsport=X at the command line this will just re-set tlsport to X. options["tlsport"] = tls_port @@ -1497,6 +1504,8 @@ def bootstrap_tls( cert_details=cert_details, fallback_ip_address=fallback_ip_address, ) + # close the socket now, so the stunnel process can bind to the port + sock.close() tunnel_args = [_stunnel_bin(), stunnel_config_file] if "netns" in options: tunnel_args = ["nsenter", "--net=" + options["netns"]] + tunnel_args diff --git a/test/mount_efs_test/test_bootstrap_tls.py b/test/mount_efs_test/test_bootstrap_tls.py index 7aaaa19e..5fa59b79 100644 --- a/test/mount_efs_test/test_bootstrap_tls.py +++ b/test/mount_efs_test/test_bootstrap_tls.py @@ -169,6 +169,9 @@ def test_bootstrap_tls_non_default_port(mocker, tmpdir): popen_mock, write_config_mock = setup_mocks(mocker) mocker.patch("os.rename") state_file_dir = str(tmpdir) + fake_sock = MagicMock() + fake_sock.getsockname.return_value = ("localhost", 1000) + mocker.patch("socket.socket", return_value=fake_sock) tls_port = 1000 mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") diff --git a/test/mount_efs_test/test_choose_tls_port.py b/test/mount_efs_test/test_choose_tls_port.py index 2ad627b3..45d47ae6 100644 --- a/test/mount_efs_test/test_choose_tls_port.py +++ b/test/mount_efs_test/test_choose_tls_port.py @@ -8,6 +8,7 @@ import mount_efs import pytest +import random from mock import MagicMock from .. import utils @@ -19,6 +20,9 @@ DEFAULT_TLS_PORT_RANGE_LOW = 20049 DEFAULT_TLS_PORT_RANGE_HIGH = 20449 +FS_ID = "fs-deadbeef" +MOUNT_POINT = "/mnt" +STATE_FILE_DIR = "/tmp" def _get_config(): @@ -41,10 +45,14 @@ def _get_config(): def test_choose_tls_port_first_try(mocker): - mocker.patch("socket.socket", return_value=MagicMock()) + fake_sock = MagicMock() + tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH) + fake_sock.getsockname.return_value = ("localhost", tls_port) + mocker.patch("socket.socket", return_value=fake_sock) options = {} - tls_port = mount_efs.choose_tls_port(_get_config(), options) + sock = mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) + tls_port = sock.getsockname()[1] assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH @@ -52,11 +60,14 @@ def test_choose_tls_port_first_try(mocker): def test_choose_tls_port_second_try(mocker): bad_sock = MagicMock() bad_sock.bind.side_effect = [socket.error, None] + tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH) + bad_sock.getsockname.return_value = ("localhost", tls_port) options = {} mocker.patch("socket.socket", return_value=bad_sock) - tls_port = mount_efs.choose_tls_port(_get_config(), options) + sock = mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) + tls_port = sock.getsockname()[1] assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH assert 2 == bad_sock.bind.call_count @@ -70,7 +81,7 @@ def test_choose_tls_port_never_succeeds(mocker, capsys): mocker.patch("socket.socket", return_value=bad_sock) with pytest.raises(SystemExit) as ex: - mount_efs.choose_tls_port(_get_config(), options) + mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) assert 0 != ex.value.code @@ -84,10 +95,13 @@ def test_choose_tls_port_never_succeeds(mocker, capsys): def test_choose_tls_port_option_specified(mocker): - mocker.patch("socket.socket", return_value=MagicMock()) + fake_sock = MagicMock() + fake_sock.getsockname.return_value = ("localhost", 1000) + mocker.patch("socket.socket", return_value=fake_sock) options = {"tlsport": 1000} - tls_port = mount_efs.choose_tls_port(_get_config(), options) + sock = mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) + tls_port = sock.getsockname()[1] assert 1000 == tls_port @@ -100,7 +114,7 @@ def test_choose_tls_port_option_specified_unavailable(mocker, capsys): mocker.patch("socket.socket", return_value=bad_sock) with pytest.raises(SystemExit) as ex: - mount_efs.choose_tls_port(_get_config(), options) + mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) assert 0 != ex.value.code @@ -116,7 +130,7 @@ def test_choose_tls_port_under_netns(mocker, capsys): mocker.patch("socket.socket", return_value=MagicMock()) options = {"netns": "/proc/1000/ns/net"} - mount_efs.choose_tls_port(_get_config(), options) + mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) utils.assert_called(setns_mock) @@ -129,3 +143,21 @@ def test_verify_tls_port(mocker): result = mount_efs.verify_tlsport_can_be_connected(1000) assert result is True assert 2 == sock.connect.call_count + +def test_choose_tls_port_already_configured(mocker, capsys): + fake_sock = MagicMock() + tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH) + fake_sock.getsockname.return_value = ("localhost", tls_port) + mocker.patch("socket.socket", return_value=fake_sock) + access_mock = mocker.patch("os.access", return_value=True) + options = {} + + with pytest.raises(SystemExit) as ex: + mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) + + assert 0 != ex.value.code + + out, err = capsys.readouterr() + assert "Failed to locate an available port" in err + + utils.assert_called_n_times(access_mock, DEFAULT_TLS_PORT_RANGE_HIGH - DEFAULT_TLS_PORT_RANGE_LOW)