diff --git a/src/mount_efs/__init__.py b/src/mount_efs/__init__.py index c6a88507..f943be10 100755 --- a/src/mount_efs/__init__.py +++ b/src/mount_efs/__init__.py @@ -929,7 +929,7 @@ def get_tls_port_range(config): return lower_bound, upper_bound -def choose_tls_port(config, options): +def choose_tls_port(state_file_dir, fs_id, mountpoint, config, options): if "tlsport" in options: ports_to_try = [int(options["tlsport"])] else: @@ -944,13 +944,13 @@ def choose_tls_port(config, options): assert len(tls_ports) == len(ports_to_try) if "netns" not in options: - tls_port = find_tls_port_in_range(ports_to_try) + sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try) else: with NetNS(nspath=options["netns"]): - tls_port = find_tls_port_in_range(ports_to_try) + sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try) - if tls_port: - return tls_port + if sock: + return sock if "tlsport" in options: fatal_error( @@ -964,14 +964,18 @@ def choose_tls_port(config, options): ) -def find_tls_port_in_range(ports_to_try): +def find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try): sock = socket.socket() for tls_port in ports_to_try: + mount_filename = get_mount_specific_filename(fs_id, mountpoint, tls_port) + config_file = get_stunnel_config_filename(state_file_dir, mount_filename) + if os.access(config_file, os.R_OK): + logging.info("confifguration for port %s already exists, trying another port", tls_port) + continue try: logging.info("binding %s", tls_port) sock.bind(("localhost", tls_port)) - sock.close() - return tls_port + return sock except socket.error as e: logging.info(e) continue @@ -1219,9 +1223,7 @@ def write_stunnel_config_file( ) logging.debug("Writing stunnel configuration:\n%s", stunnel_config) - stunnel_config_file = os.path.join( - state_file_dir, "stunnel-config.%s" % mount_filename - ) + stunnel_config_file = get_stunnel_config_filename(state_file_dir, mount_filename) with open(stunnel_config_file, "w") as f: f.write(stunnel_config) @@ -1419,6 +1421,10 @@ def create_required_directory(config, directory): raise +def get_stunnel_config_filename(state_file_dir, mount_filename): + return os.path.join(state_file_dir, "stunnel-config.%s" % mount_filename) + + @contextmanager def bootstrap_tls( config, @@ -1430,7 +1436,8 @@ def bootstrap_tls( state_file_dir=STATE_FILE_DIR, fallback_ip_address=None, ): - tls_port = choose_tls_port(config, options) + sock = choose_tls_port(state_file_dir, fs_id, mountpoint, config, options) + tls_port = sock.getsockname()[1] # override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel. # if the user has specified tlsport=X at the command line this will just re-set tlsport to X. options["tlsport"] = tls_port @@ -1506,6 +1513,8 @@ def bootstrap_tls( cert_details=cert_details, fallback_ip_address=fallback_ip_address, ) + # close the socket now, so the stunnel process can bind to the port + sock.close() tunnel_args = [_stunnel_bin(), stunnel_config_file] if "netns" in options: tunnel_args = ["nsenter", "--net=" + options["netns"]] + tunnel_args diff --git a/test/mount_efs_test/test_bootstrap_tls.py b/test/mount_efs_test/test_bootstrap_tls.py index 18616f70..eb722e57 100644 --- a/test/mount_efs_test/test_bootstrap_tls.py +++ b/test/mount_efs_test/test_bootstrap_tls.py @@ -169,6 +169,9 @@ def test_bootstrap_tls_non_default_port(mocker, tmpdir): popen_mock, write_config_mock = setup_mocks(mocker) mocker.patch("os.rename") state_file_dir = str(tmpdir) + fake_sock = MagicMock() + fake_sock.getsockname.return_value = ("localhost", 1000) + mocker.patch("socket.socket", return_value=fake_sock) tls_port = 1000 mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") diff --git a/test/mount_efs_test/test_choose_tls_port.py b/test/mount_efs_test/test_choose_tls_port.py index 800964e2..6b868d74 100644 --- a/test/mount_efs_test/test_choose_tls_port.py +++ b/test/mount_efs_test/test_choose_tls_port.py @@ -5,6 +5,7 @@ # the License. import socket +import random from unittest.mock import MagicMock import pytest @@ -20,6 +21,9 @@ DEFAULT_TLS_PORT_RANGE_LOW = 20049 DEFAULT_TLS_PORT_RANGE_HIGH = 20449 +FS_ID = "fs-deadbeef" +MOUNT_POINT = "/mnt" +STATE_FILE_DIR = "/tmp" def _get_config(): @@ -42,10 +46,14 @@ def _get_config(): def test_choose_tls_port_first_try(mocker): - mocker.patch("socket.socket", return_value=MagicMock()) + fake_sock = MagicMock() + tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH) + fake_sock.getsockname.return_value = ("localhost", tls_port) + mocker.patch("socket.socket", return_value=fake_sock) options = {} - tls_port = mount_efs.choose_tls_port(_get_config(), options) + sock = mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) + tls_port = sock.getsockname()[1] assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH @@ -53,11 +61,14 @@ def test_choose_tls_port_first_try(mocker): def test_choose_tls_port_second_try(mocker): bad_sock = MagicMock() bad_sock.bind.side_effect = [socket.error, None] + tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH) + bad_sock.getsockname.return_value = ("localhost", tls_port) options = {} mocker.patch("socket.socket", return_value=bad_sock) - tls_port = mount_efs.choose_tls_port(_get_config(), options) + sock = mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) + tls_port = sock.getsockname()[1] assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH assert 2 == bad_sock.bind.call_count @@ -71,7 +82,7 @@ def test_choose_tls_port_never_succeeds(mocker, capsys): mocker.patch("socket.socket", return_value=bad_sock) with pytest.raises(SystemExit) as ex: - mount_efs.choose_tls_port(_get_config(), options) + mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) assert 0 != ex.value.code @@ -85,10 +96,13 @@ def test_choose_tls_port_never_succeeds(mocker, capsys): def test_choose_tls_port_option_specified(mocker): - mocker.patch("socket.socket", return_value=MagicMock()) + fake_sock = MagicMock() + fake_sock.getsockname.return_value = ("localhost", 1000) + mocker.patch("socket.socket", return_value=fake_sock) options = {"tlsport": 1000} - tls_port = mount_efs.choose_tls_port(_get_config(), options) + sock = mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) + tls_port = sock.getsockname()[1] assert 1000 == tls_port @@ -101,7 +115,7 @@ def test_choose_tls_port_option_specified_unavailable(mocker, capsys): mocker.patch("socket.socket", return_value=bad_sock) with pytest.raises(SystemExit) as ex: - mount_efs.choose_tls_port(_get_config(), options) + mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) assert 0 != ex.value.code @@ -117,7 +131,7 @@ def test_choose_tls_port_under_netns(mocker, capsys): mocker.patch("socket.socket", return_value=MagicMock()) options = {"netns": "/proc/1000/ns/net"} - mount_efs.choose_tls_port(_get_config(), options) + mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) utils.assert_called(setns_mock) @@ -130,3 +144,21 @@ def test_verify_tls_port(mocker): result = mount_efs.verify_tlsport_can_be_connected(1000) assert result is True assert 2 == sock.connect.call_count + +def test_choose_tls_port_already_configured(mocker, capsys): + fake_sock = MagicMock() + tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH) + fake_sock.getsockname.return_value = ("localhost", tls_port) + mocker.patch("socket.socket", return_value=fake_sock) + access_mock = mocker.patch("os.access", return_value=True) + options = {} + + with pytest.raises(SystemExit) as ex: + mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options) + + assert 0 != ex.value.code + + out, err = capsys.readouterr() + assert "Failed to locate an available port" in err + + utils.assert_called_n_times(access_mock, DEFAULT_TLS_PORT_RANGE_HIGH - DEFAULT_TLS_PORT_RANGE_LOW)