Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race in stunnel port selection #129

Merged
merged 1 commit into from Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
171 changes: 92 additions & 79 deletions src/mount_efs/__init__.py
Expand Up @@ -939,7 +939,7 @@ def get_tls_port_range(config):
return lower_bound, upper_bound


def choose_tls_port(config, options):
def choose_tls_port_and_bind_sock(state_file_dir, fs_id, mountpoint, config, options):
if "tlsport" in options:
ports_to_try = [int(options["tlsport"])]
else:
Expand All @@ -954,13 +954,14 @@ def choose_tls_port(config, options):
assert len(tls_ports) == len(ports_to_try)

if "netns" not in options:
tls_port = find_tls_port_in_range(ports_to_try)
tls_port_sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try)
else:
with NetNS(nspath=options["netns"]):
tls_port = find_tls_port_in_range(ports_to_try)
tls_port_sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try)

if tls_port:
return tls_port
if tls_port_sock:
tls_port = tls_port_sock.getsockname()[1]
return tls_port_sock, tls_port

if "tlsport" in options:
fatal_error(
Expand All @@ -974,14 +975,18 @@ def choose_tls_port(config, options):
)


def find_tls_port_in_range(ports_to_try):
def find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try):
sock = socket.socket()
for tls_port in ports_to_try:
mount_filename = get_mount_specific_filename(fs_id, mountpoint, tls_port)
config_file = get_stunnel_config_filename(state_file_dir, mount_filename)
if os.access(config_file, os.R_OK):
logging.info("confifguration for port %s already exists, trying another port", tls_port)
continue
Comment on lines +981 to +985
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary? If the port is already in use, the bind will fail anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not necessary, but when the port binding fails it helps to distinguish whether or not we're clashing with other processes. Makes debugging easier.

try:
logging.info("binding %s", tls_port)
sock.bind(("localhost", tls_port))
sock.close()
return tls_port
return sock
except socket.error as e:
logging.info(e)
continue
Expand Down Expand Up @@ -1262,9 +1267,7 @@ def write_stunnel_config_file(
)
logging.debug("Writing stunnel configuration:\n%s", stunnel_config)

stunnel_config_file = os.path.join(
state_file_dir, "stunnel-config.%s" % mount_filename
)
stunnel_config_file = get_stunnel_config_filename(state_file_dir, mount_filename)

with open(stunnel_config_file, "w") as f:
f.write(stunnel_config)
Expand Down Expand Up @@ -1464,6 +1467,10 @@ def create_required_directory(config, directory):
raise


def get_stunnel_config_filename(state_file_dir, mount_filename):
    """Return the path of the stunnel configuration file for a mount.

    The file sits directly inside ``state_file_dir`` and its name is
    ``stunnel-config.`` followed by ``mount_filename``.
    """
    config_basename = "stunnel-config.%s" % mount_filename
    return os.path.join(state_file_dir, config_basename)


@contextmanager
def bootstrap_tls(
config,
Expand All @@ -1475,82 +1482,88 @@ def bootstrap_tls(
state_file_dir=STATE_FILE_DIR,
fallback_ip_address=None,
):
tls_port = choose_tls_port(config, options)
# override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel.
# if the user has specified tlsport=X at the command line this will just re-set tlsport to X.
options["tlsport"] = tls_port

use_iam = "iam" in options
ap_id = options.get("accesspoint")
cert_details = {}
security_credentials = None
client_info = get_client_info(config)
region = get_target_region(config)

if use_iam:
aws_creds_uri = options.get("awscredsuri")
if aws_creds_uri:
kwargs = {"aws_creds_uri": aws_creds_uri}
else:
kwargs = {"awsprofile": get_aws_profile(options, use_iam)}
sock, tls_port = choose_tls_port_and_bind_sock(state_file_dir, fs_id, mountpoint, config, options)
try:
# override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel.
# if the user has specified tlsport=X at the command line this will just re-set tlsport to X.
options["tlsport"] = tls_port

use_iam = "iam" in options
ap_id = options.get("accesspoint")
cert_details = {}
security_credentials = None
client_info = get_client_info(config)
region = get_target_region(config)

if use_iam:
aws_creds_uri = options.get("awscredsuri")
if aws_creds_uri:
kwargs = {"aws_creds_uri": aws_creds_uri}
else:
kwargs = {"awsprofile": get_aws_profile(options, use_iam)}

security_credentials, credentials_source = get_aws_security_credentials(
config, use_iam, region, **kwargs
)
security_credentials, credentials_source = get_aws_security_credentials(
config, use_iam, region, **kwargs
)

if credentials_source:
cert_details["awsCredentialsMethod"] = credentials_source
if credentials_source:
cert_details["awsCredentialsMethod"] = credentials_source

if ap_id:
cert_details["accessPoint"] = ap_id
if ap_id:
cert_details["accessPoint"] = ap_id

# additional symbol appended to avoid naming collisions
cert_details["mountStateDir"] = (
get_mount_specific_filename(fs_id, mountpoint, tls_port) + "+"
)
# common name for certificate signing request is max 64 characters
cert_details["commonName"] = socket.gethostname()[0:64]
region = get_target_region(config)
cert_details["region"] = region
cert_details["certificateCreationTime"] = create_certificate(
config,
cert_details["mountStateDir"],
cert_details["commonName"],
cert_details["region"],
fs_id,
security_credentials,
ap_id,
client_info,
base_path=state_file_dir,
)
cert_details["certificate"] = os.path.join(
state_file_dir, cert_details["mountStateDir"], "certificate.pem"
)
cert_details["privateKey"] = get_private_key_path()
cert_details["fsId"] = fs_id
# additional symbol appended to avoid naming collisions
cert_details["mountStateDir"] = (
get_mount_specific_filename(fs_id, mountpoint, tls_port) + "+"
)
# common name for certificate signing request is max 64 characters
cert_details["commonName"] = socket.gethostname()[0:64]
region = get_target_region(config)
cert_details["region"] = region
cert_details["certificateCreationTime"] = create_certificate(
config,
cert_details["mountStateDir"],
cert_details["commonName"],
cert_details["region"],
fs_id,
security_credentials,
ap_id,
client_info,
base_path=state_file_dir,
)
cert_details["certificate"] = os.path.join(
state_file_dir, cert_details["mountStateDir"], "certificate.pem"
)
cert_details["privateKey"] = get_private_key_path()
cert_details["fsId"] = fs_id

start_watchdog(init_system)
start_watchdog(init_system)

if not os.path.exists(state_file_dir):
create_required_directory(config, state_file_dir)
if not os.path.exists(state_file_dir):
create_required_directory(config, state_file_dir)

verify_level = int(options.get("verify", DEFAULT_STUNNEL_VERIFY_LEVEL))
ocsp_enabled = is_ocsp_enabled(config, options)
verify_level = int(options.get("verify", DEFAULT_STUNNEL_VERIFY_LEVEL))
ocsp_enabled = is_ocsp_enabled(config, options)

stunnel_config_file = write_stunnel_config_file(
config,
state_file_dir,
fs_id,
mountpoint,
tls_port,
dns_name,
verify_level,
ocsp_enabled,
options,
region,
cert_details=cert_details,
fallback_ip_address=fallback_ip_address,
)
stunnel_config_file = write_stunnel_config_file(
config,
state_file_dir,
fs_id,
mountpoint,
tls_port,
dns_name,
verify_level,
ocsp_enabled,
options,
region,
cert_details=cert_details,
fallback_ip_address=fallback_ip_address,
)
except Exception as e:
logging.error("Error while creating the configuration file: %s" % e)
finally:
# close the socket now, so the stunnel process can bind to the port
sock.close()
tunnel_args = [_stunnel_bin(), stunnel_config_file]
if "netns" in options:
tunnel_args = ["nsenter", "--net=" + options["netns"]] + tunnel_args
Expand Down
3 changes: 3 additions & 0 deletions test/mount_efs_test/test_bootstrap_tls.py
Expand Up @@ -170,6 +170,9 @@ def test_bootstrap_tls_non_default_port(mocker, tmpdir):
popen_mock, write_config_mock = setup_mocks(mocker)
mocker.patch("os.rename")
state_file_dir = str(tmpdir)
fake_sock = MagicMock()
fake_sock.getsockname.return_value = ("localhost", 1000)
mocker.patch("socket.socket", return_value=fake_sock)

tls_port = 1000
mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel")
Expand Down
45 changes: 37 additions & 8 deletions test/mount_efs_test/test_choose_tls_port.py
Expand Up @@ -5,6 +5,7 @@
# the License.

import socket
import random
from unittest.mock import MagicMock

import pytest
Expand All @@ -20,6 +21,9 @@

DEFAULT_TLS_PORT_RANGE_LOW = 20049
DEFAULT_TLS_PORT_RANGE_HIGH = 20449
FS_ID = "fs-deadbeef"
MOUNT_POINT = "/mnt"
STATE_FILE_DIR = "/tmp"


def _get_config():
Expand All @@ -42,22 +46,27 @@ def _get_config():


def test_choose_tls_port_first_try(mocker):
mocker.patch("socket.socket", return_value=MagicMock())
fake_sock = MagicMock()
tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH)
fake_sock.getsockname.return_value = ("localhost", tls_port)
mocker.patch("socket.socket", return_value=fake_sock)
options = {}

tls_port = mount_efs.choose_tls_port(_get_config(), options)
sock, tls_port = mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH


def test_choose_tls_port_second_try(mocker):
bad_sock = MagicMock()
bad_sock.bind.side_effect = [socket.error, None]
tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH)
bad_sock.getsockname.return_value = ("localhost", tls_port)
options = {}

mocker.patch("socket.socket", return_value=bad_sock)

tls_port = mount_efs.choose_tls_port(_get_config(), options)
sock, tls_port = mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH
assert 2 == bad_sock.bind.call_count
Expand All @@ -71,7 +80,7 @@ def test_choose_tls_port_never_succeeds(mocker, capsys):
mocker.patch("socket.socket", return_value=bad_sock)

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port(_get_config(), options)
mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 0 != ex.value.code

Expand All @@ -85,10 +94,12 @@ def test_choose_tls_port_never_succeeds(mocker, capsys):


def test_choose_tls_port_option_specified(mocker):
mocker.patch("socket.socket", return_value=MagicMock())
fake_sock = MagicMock()
fake_sock.getsockname.return_value = ("localhost", 1000)
mocker.patch("socket.socket", return_value=fake_sock)
options = {"tlsport": 1000}

tls_port = mount_efs.choose_tls_port(_get_config(), options)
sock, tls_port = mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 1000 == tls_port

Expand All @@ -101,7 +112,7 @@ def test_choose_tls_port_option_specified_unavailable(mocker, capsys):
mocker.patch("socket.socket", return_value=bad_sock)

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port(_get_config(), options)
mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 0 != ex.value.code

Expand All @@ -117,7 +128,7 @@ def test_choose_tls_port_under_netns(mocker, capsys):
mocker.patch("socket.socket", return_value=MagicMock())
options = {"netns": "/proc/1000/ns/net"}

mount_efs.choose_tls_port(_get_config(), options)
mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)
utils.assert_called(setns_mock)


Expand All @@ -130,3 +141,21 @@ def test_verify_tls_port(mocker):
result = mount_efs.verify_tlsport_can_be_connected(1000)
assert result is True
assert 2 == sock.connect.call_count

def test_choose_tls_port_already_configured(mocker, capsys):
    """Port selection must fail when every candidate already has a config.

    ``os.access`` is patched to report each stunnel config file as readable,
    so the call is expected to exit with a non-zero status and print the
    "Failed to locate an available port" message to stderr.
    """
    bound_port = random.randrange(
        DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH
    )
    sock_stub = MagicMock()
    sock_stub.getsockname.return_value = ("localhost", bound_port)
    mocker.patch("socket.socket", return_value=sock_stub)
    # Pretend a readable stunnel config file exists for whatever path is probed.
    access_mock = mocker.patch("os.access", return_value=True)
    mount_options = {}

    with pytest.raises(SystemExit) as exit_info:
        mount_efs.choose_tls_port_and_bind_sock(
            STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), mount_options
        )

    assert 0 != exit_info.value.code

    _, captured_err = capsys.readouterr()
    assert "Failed to locate an available port" in captured_err

    # The test expects one os.access probe per port in the default range.
    expected_probes = DEFAULT_TLS_PORT_RANGE_HIGH - DEFAULT_TLS_PORT_RANGE_LOW
    utils.assert_called_n_times(access_mock, expected_probes)