diff --git a/lib/rucio/common/types.py b/lib/rucio/common/types.py index d085839695..fc4007a7b9 100644 --- a/lib/rucio/common/types.py +++ b/lib/rucio/common/types.py @@ -178,3 +178,11 @@ class RuleDict(TypedDict): class DIDDict(TypedDict): name: str scope: InternalScope + + +class IPDict(TypedDict): + ip: str + fqdn: str + site: str + latitude: Optional[float] + longitude: Optional[float] diff --git a/lib/rucio/common/utils.py b/lib/rucio/common/utils.py index 3f16b28561..a70ec90c71 100644 --- a/lib/rucio/common/utils.py +++ b/lib/rucio/common/utils.py @@ -37,12 +37,12 @@ import time import zlib from collections import OrderedDict -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence from enum import Enum from functools import partial, wraps from io import StringIO from itertools import zip_longest -from typing import TYPE_CHECKING, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse from uuid import uuid4 as uuid from xml.etree import ElementTree @@ -65,6 +65,10 @@ if TYPE_CHECKING: T = TypeVar('T') + from _typeshed import FileDescriptorOrPath + from sqlalchemy.orm import Session + + from rucio.common.types import IPDict, LoggerFunction # HTTP code dictionary. Not complete. Can be extended if needed. @@ -97,7 +101,7 @@ DATE_FORMAT = '%a, %d %b %Y %H:%M:%S UTC' -def invert_dict(d): +def invert_dict(d: dict[Any, Any]) -> dict[Any, Any]: """ Invert the dictionary. CAUTION: this function is not deterministic unless the input dictionary is one-to-one mapping. @@ -108,11 +112,11 @@ def invert_dict(d): return {value: key for key, value in d.items()} -def dids_as_dicts(did_list): +def dids_as_dicts(did_list: Iterable[str | dict[str, str]]) -> list[dict[str, str]]: """ Converts list of DIDs to list of dictionaries - :param did_list: list of DIDs as either "scope:name" or {"scope":"scope", "name","name"} - :returns: list of dictionaries {"scope":"scope", "name","name"} + :param did_list: list of DIDs as either "scope:name" or {"scope":"scope", "name":"name"} + :returns: list of dictionaries {"scope":"scope", "name":"name"} """ out = [] for did in did_list: @@ -128,7 +132,12 @@ def dids_as_dicts(did_list): return out -def build_url(url, path=None, params=None, doseq=False): +def build_url( + url: str, + path: Optional[str] = None, + params: Optional[str] = None, + doseq: bool = False +) -> str: """ utitily function to build an url for requests to the rucio system. @@ -147,7 +156,13 @@ def build_url(url, path=None, params=None, doseq=False): return complete_url -def all_oidc_req_claims_present(scope, audience, required_scope, required_audience, sepatator=" "): +def all_oidc_req_claims_present( + scope: str | list[str], + audience: str | list[str], + required_scope: str | list[str], + required_audience: str | list[str], + sepatator: str = " " +) -> bool: """ Checks if both of the following statements are true: - all items in required_scope are present in scope string @@ -206,11 +221,11 @@ def all_oidc_req_claims_present(scope, audience, required_scope, required_audien return False -def generate_uuid(): +def generate_uuid() -> str: return str(uuid()).replace('-', '').lower() -def generate_uuid_bytes(): +def generate_uuid_bytes() -> bytes: return uuid().bytes @@ -221,7 +236,7 @@ def generate_uuid_bytes(): CHECKSUM_KEY = 'supported_checksums' -def is_checksum_valid(checksum_name): +def is_checksum_valid(checksum_name: str) -> bool: """ A simple function to check wether a checksum algorithm is supported. Relies on GLOBALLY_SUPPORTED_CHECKSUMS to allow for expandability. @@ -233,20 +248,19 @@ def is_checksum_valid(checksum_name): return checksum_name in GLOBALLY_SUPPORTED_CHECKSUMS -def set_preferred_checksum(checksum_name): +def set_preferred_checksum(checksum_name: str) -> None: """ - A simple function to check wether a checksum algorithm is supported. - Relies on GLOBALLY_SUPPORTED_CHECKSUMS to allow for expandability. + If the input checksum name is valid, + set it as PREFERRED_CHECKSUM. :param checksum_name: The name of the checksum to be verified. - :returns: True if checksum_name is in GLOBALLY_SUPPORTED_CHECKSUMS list, False otherwise. """ if is_checksum_valid(checksum_name): global PREFERRED_CHECKSUM PREFERRED_CHECKSUM = checksum_name -def adler32(file): +def adler32(file: "FileDescriptorOrPath") -> str: """ An Adler-32 checksum is obtained by calculating two 16-bit checksums A and B and concatenating their bits into a 32-bit integer. A is the sum of all bytes in the @@ -293,7 +307,7 @@ def adler32(file): CHECKSUM_ALGO_DICT['adler32'] = adler32 -def md5(file): +def md5(file: "FileDescriptorOrPath") -> str: """ Runs the MD5 algorithm (RFC-1321) on the binary content of the file named file and returns the hexadecimal digest @@ -313,7 +327,7 @@ def md5(file): CHECKSUM_ALGO_DICT['md5'] = md5 -def sha256(file): +def sha256(file: "FileDescriptorOrPath") -> str: """ Runs the SHA256 algorithm on the binary content of the file named file and returns the hexadecimal digest @@ -330,7 +344,7 @@ def sha256(file): CHECKSUM_ALGO_DICT['sha256'] = sha256 -def crc32(file): +def crc32(file: FileDescriptorOrPath) -> str: """ Runs the CRC32 algorithm on the binary content of the file named file and returns the hexadecimal digest @@ -346,7 +360,7 @@ def crc32(file): CHECKSUM_ALGO_DICT['crc32'] = crc32 -def _next_pow2(num): +def _next_pow2(num: int) -> int: if not num: return 0 return math.ceil(math.log2(num)) @@ -392,7 +406,7 @@ def bittorrent_v2_piece_length(file_size: int) -> int: return 2 ** _bittorrent_v2_piece_length_pow2(file_size) -def bittorrent_v2_merkle_sha256(file) -> tuple[bytes, bytes, int]: +def bittorrent_v2_merkle_sha256(file: "FileDescriptorOrPath") -> tuple[bytes, bytes, int]: """ Compute the .torrent v2 hash tree for the given file. (http://www.bittorrent.org/beps/bep_0052.html) @@ -483,7 +497,7 @@ def _merkle_root(leafs: list[bytes], nb_levels: int, padding: bytes) -> bytes: return pieces_root, pieces_layers, piece_length -def merkle_sha256(file) -> str: +def merkle_sha256(file: "FileDescriptorOrPath") -> str: """ The root of the sha256 merkle hash tree with leaf size of 16 KiB. """ @@ -494,7 +508,7 @@ def merkle_sha256(file) -> str: CHECKSUM_ALGO_DICT['merkle_sha256'] = merkle_sha256 -def bencode(obj) -> bytes: +def bencode(obj: int | bytes | str | list | dict[bytes, Any]) -> bytes: """ Copied from the reference implementation of v2 bittorrent: http://bittorrent.org/beps/bep_0052_torrent_creator.py @@ -515,7 +529,7 @@ def bencode(obj) -> bytes: return b"d" + b"".join(map(bencode, itertools.chain(*items))) + b"e" else: raise ValueError("dict keys should be bytes " + str(obj.keys())) - raise ValueError("Allowed types: int, bytes, list, dict; not %s", type(obj)) + raise ValueError("Allowed types: int, bytes, str, list, dict; not %s", type(obj)) def construct_torrent( @@ -558,7 +572,7 @@ def construct_torrent( return torrent_id, torrent -def str_to_date(string): +def str_to_date(string: str) -> Optional[datetime.datetime]: """ Converts a RFC-1123 string to the corresponding datetime value. :param string: the RFC-1123 string to convert to datetime value. @@ -566,7 +580,7 @@ def str_to_date(string): return datetime.datetime.strptime(string, DATE_FORMAT) if string else None -def val_to_space_sep_str(vallist): +def val_to_space_sep_str(vallist: list[str]) -> str: """ Converts a list of values into a string of space separated values :param vallist: the list of values to to convert into string @@ -581,7 +595,7 @@ def val_to_space_sep_str(vallist): return '' -def date_to_str(date): +def date_to_str(date: datetime.datetime) -> Optional[str]: """ Converts a datetime value to the corresponding RFC-1123 string. :param date: the datetime value to convert. @@ -611,19 +625,19 @@ def default(self, obj): # pylint: disable=E0202 return json.JSONEncoder.default(self, obj) -def render_json(**data): +def render_json(**data: Any) -> str: """ JSON render function """ return json.dumps(data, cls=APIEncoder) -def render_json_list(list_): +def render_json_list(list_) -> str: """ JSON render function for list """ return json.dumps(list_, cls=APIEncoder) -def datetime_parser(dct): +def datetime_parser(dct: dict[Any, Any]) -> dict[Any, Any]: """ datetime parser """ for k, v in list(dct.items()): @@ -635,17 +649,17 @@ def datetime_parser(dct): return dct -def parse_response(data): +def parse_response(data: str | bytes | bytearray) -> Any: """ JSON render function """ - if hasattr(data, 'decode'): + if isinstance(data, (bytes, bytearray)): data = data.decode('utf-8') return json.loads(data, object_hook=datetime_parser) -def execute(cmd) -> tuple[int, str, str]: +def execute(cmd: str) -> tuple[int, str, str]: """ Executes a command in a subprocess. Returns a tuple of (exitcode, out, err), where out is the string output @@ -667,17 +681,17 @@ def execute(cmd) -> tuple[int, str, str]: return exitcode, out.decode(encoding='utf-8'), err.decode(encoding='utf-8') -def rse_supported_protocol_operations(): +def rse_supported_protocol_operations() -> list[str]: """ Returns a list with operations supported by all RSE protocols.""" return ['read', 'write', 'delete', 'third_party_copy_read', 'third_party_copy_write'] -def rse_supported_protocol_domains(): - """ Returns a list with all supoorted RSE protocol domains.""" +def rse_supported_protocol_domains() -> list[str]: + """ Returns a list with all supported RSE protocol domains.""" return ['lan', 'wan'] -def grouper(iterable, n, fillvalue=None): +def grouper(iterable: Iterable[Any], n: int, fillvalue: Optional[object] = None) -> zip_longest: """ Collect data into fixed-length chunks or blocks """ # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx args = [iter(iterable)] * n @@ -700,7 +714,7 @@ def chunks(iterable, n): yield chunk -def dict_chunks(dict_, n): +def dict_chunks(dict_: dict[Any, Any], n: int) -> Iterator[dict[Any, Any]]: """ Iterate over the dictionary in groups of the requested size """ @@ -709,13 +723,13 @@ def dict_chunks(dict_, n): yield {k: dict_[k] for k in itertools.islice(it, n)} -def my_key_generator(namespace, fn, **kw): +def my_key_generator(namespace: str, fn: Callable, **kw) -> Callable[..., str]: """ - Customyzed key generator for dogpile + Customized key generator for dogpile """ fname = fn.__name__ - def generate_key(*arg, **kw): + def generate_key(*arg, **kw) -> str: return namespace + "_" + fname + "_".join(str(s) for s in filter(None, arg)) return generate_key @@ -850,7 +864,7 @@ def construct_surl_DQ2(dsn: str, scope: str, filename: str) -> str: return '/%s/%s/%s/%s/%s' % (project, dataset_type, tag, stripped_dsn, filename) @staticmethod - def construct_surl_T0(dsn: str, scope: str, filename: str) -> str: + def construct_surl_T0(dsn: str, scope: str, filename: str) -> Optional[str]: """ Defines relative SURL for new replicas. This method contains Tier0 convention. To be used for non-deterministic sites. @@ -890,7 +904,7 @@ def construct_surl_BelleII(dsn: str, scope: str, filename: str) -> str: SurlAlgorithms._module_init_() -def construct_surl(dsn: str, scope: str, filename: str, naming_convention: str = None) -> str: +def construct_surl(dsn: str, scope: str, filename: str, naming_convention: Optional[str] = None) -> str: """ Applies non-deterministic source url convention to the given replica. use the naming_convention to call the actual function which will do the job. @@ -904,7 +918,7 @@ def construct_surl(dsn: str, scope: str, filename: str, naming_convention: str = return surl_algorithms.construct_surl(dsn, scope, filename, naming_convention) -def clean_surls(surls): +def clean_surls(surls: Iterable[str]) -> list[str]: res = [] for surl in surls: if surl.startswith('srm'): @@ -1083,7 +1097,11 @@ def extract_scope_belleii(did: str, scopes: Optional[Sequence[str]]) -> Sequence ScopeExtractionAlgorithms._module_init_() -def extract_scope(did, scopes=None, default_extract=_DEFAULT_EXTRACT): +def extract_scope( + did: str, + scopes: Optional[Sequence[str]] = None, + default_extract: str = _DEFAULT_EXTRACT +) -> Sequence[str]: scope_extraction_algorithms = ScopeExtractionAlgorithms() extract_scope_convention = config_get('common', 'extract_scope', False, None) or config_get('policy', 'extract_scope', False, None) if extract_scope_convention is None or not ScopeExtractionAlgorithms.supports(extract_scope_convention): @@ -1091,7 +1109,7 @@ def extract_scope(did, scopes=None, default_extract=_DEFAULT_EXTRACT): return scope_extraction_algorithms.extract_scope(did, scopes, extract_scope_convention) -def pid_exists(pid): +def pid_exists(pid: int) -> bool: """ Check whether pid exists in the current process table. UNIX only. @@ -1121,7 +1139,7 @@ def pid_exists(pid): return True -def sizefmt(num, human=True): +def sizefmt(num: int | float, human: bool = True) -> str: """ Print human readable file sizes """ @@ -1141,7 +1159,7 @@ def sizefmt(num, human=True): return 'Inf' -def get_tmp_dir(): +def get_tmp_dir() -> str: """ Get a path where to store temporary files. @@ -1169,7 +1187,7 @@ def get_tmp_dir(): return base_dir -def is_archive(name): +def is_archive(name: str) -> bool: ''' Check if a file name is an archive file or not. @@ -1215,7 +1233,7 @@ def resolve_ip(hostname: str) -> str: return hostname -def detect_client_location(): +def detect_client_location() -> IPDict: """ Normally client IP will be set on the server side (request.remote_addr) Here setting ip on the one seen by the host itself. There is no connection @@ -1271,7 +1289,7 @@ def detect_client_location(): 'longitude': longitude} -def ssh_sign(private_key, message): +def ssh_sign(private_key: str, message: str | bytes) -> str: """ Sign a string message using the private key. @@ -1293,7 +1311,7 @@ def ssh_sign(private_key, message): return base64_encoded -def make_valid_did(lfn_dict): +def make_valid_did(lfn_dict: dict[str, Any]) -> dict[str, Any]: """ When managing information about a LFN (such as in `rucio upload` or the RSE manager's upload), we add the `filename` attribute to record @@ -1313,7 +1331,7 @@ def make_valid_did(lfn_dict): return lfn_copy -def send_trace(trace, trace_endpoint, user_agent, retries=5): +def send_trace(trace: dict[str, Any], trace_endpoint: str, user_agent: str, retries: int = 5) -> int: """ Send the given trace to the trace endpoint @@ -1334,7 +1352,7 @@ def send_trace(trace, trace_endpoint, user_agent, retries=5): return 1 -def add_url_query(url, query): +def add_url_query(url: str, query: dict[str, str]) -> str: """ Add a new dictionary to URL parameters @@ -1350,7 +1368,7 @@ def add_url_query(url, query): return urlunparse(url_parts) -def get_bytes_value_from_string(input_string): +def get_bytes_value_from_string(input_string: str) -> bool | int: """ Get bytes from a string that represents a storage value and unit @@ -1380,7 +1398,7 @@ def get_bytes_value_from_string(input_string): return False -def parse_did_filter_from_string(input_string): +def parse_did_filter_from_string(input_string: str) -> tuple[dict[str, Any], str]: """ Parse DID filter options in format 'length<3,type=all' from string. @@ -1440,7 +1458,12 @@ def parse_did_filter_from_string(input_string): return filters, type_ -def parse_did_filter_from_string_fe(input_string, name='*', type='collection', omit_name=False): +def parse_did_filter_from_string_fe( + input_string: str, + name: str = '*', + type: str = 'collection', + omit_name: bool = False +) -> tuple[list[dict[str, Any]], str]: """ Parse DID filter string for the filter engine (fe). @@ -1548,7 +1571,7 @@ def parse_did_filter_from_string_fe(input_string, name='*', type='collection', o return filters, type -def parse_replicas_from_file(path): +def parse_replicas_from_file(path: "FileDescriptorOrPath") -> Any: """ Parses the output of list_replicas from a json or metalink file into a dictionary. Metalink parsing is tried first and if it fails @@ -1569,7 +1592,7 @@ def parse_replicas_from_file(path): raise MetalinkJsonParsingError(path, xml_err, json_err) -def parse_replicas_from_string(string): +def parse_replicas_from_string(string: str) -> Any: """ Parses the output of list_replicas from a json or metalink string into a dictionary. Metalink parsing is tried first and if it fails @@ -1589,7 +1612,7 @@ def parse_replicas_from_string(string): raise MetalinkJsonParsingError(string, xml_err, json_err) -def parse_replicas_metalink(root): +def parse_replicas_metalink(root: ElementTree.Element) -> list[dict[str, Any]]: """ Transforms the metalink tree into a list of dictionaries where each dictionary describes a file with its replicas. @@ -1646,11 +1669,15 @@ def parse_replicas_metalink(root): return files -def get_thread_with_periodic_running_function(interval, action, graceful_stop): +def get_thread_with_periodic_running_function( + interval: int | float, + action: Callable[..., Any], + graceful_stop: threading.Event +) -> threading.Thread: """ Get a thread where a function runs periodically. - :param interval: Interval in seconds when the action fucntion should run. + :param interval: Interval in seconds when the action function should run. :param action: Function, that should run periodically. :param graceful_stop: Threading event used to check for graceful stop. """ @@ -1663,7 +1690,7 @@ def start(): return t -def run_cmd_process(cmd, timeout=3600): +def run_cmd_process(cmd: str, timeout: int = 3600) -> tuple[int, str]: """ shell command parser with timeout @@ -1704,7 +1731,10 @@ def run_cmd_process(cmd, timeout=3600): return returncode, stdout -def api_update_return_dict(dictionary, session=None): +def api_update_return_dict( + dictionary: dict[str, Any], + session: Optional["Session"] = None +) -> dict[str, Any]: """ Ensure that rse is in a dictionary returned from core @@ -1742,7 +1772,12 @@ def api_update_return_dict(dictionary, session=None): return dictionary -def setup_logger(module_name=None, logger_name=None, logger_level=None, verbose=False): +def setup_logger( + module_name: Optional[str] = None, + logger_name: Optional[str] = None, + logger_level: Optional[int] = None, + verbose: bool = False +) -> logging.Logger: ''' Factory method to set logger with handlers. :param module_name: __name__ of the module that is calling this method @@ -1784,11 +1819,11 @@ def _force_cfg_log_level(cfg_option): logger.setLevel(logger_level) # preferred logger handling - def add_handler(logger): + def add_handler(logger: logging.Logger) -> None: hdlr = logging.StreamHandler() - def emit_decorator(fnc): - def func(*args): + def emit_decorator(fnc: Callable[..., Any]) -> Callable[..., Any]: + def func(*args) -> Callable[..., Any]: if 'RUCIO_LOGGING_FORMAT' not in os.environ: levelno = args[0].levelno format_str = '%(asctime)s\t%(levelname)s\t%(message)s\033[0m' @@ -1821,7 +1856,12 @@ def func(*args): return logger -def daemon_sleep(start_time, sleep_time, graceful_stop, logger=logging.log): +def daemon_sleep( + start_time: float, + sleep_time: float, + graceful_stop: threading.Event, + logger: "LoggerFunction" = logging.log +) -> None: """Sleeps a daemon the time provided by sleep_time""" end_time = time.time() time_diff = end_time - start_time @@ -1830,7 +1870,7 @@ def daemon_sleep(start_time, sleep_time, graceful_stop, logger=logging.log): graceful_stop.wait(sleep_time - time_diff) -def is_client(): +def is_client() -> bool: """" Checks if the function is called from a client or from a server/daemon @@ -1859,7 +1899,7 @@ def is_client(): class retry: """Retry callable object with configuragle number of attempts""" - def __init__(self, func, *args, **kwargs): + def __init__(self, func: Callable[..., Any], *args, **kwargs): ''' :param func: a method that should be executed with retries :param args: parametres of the func @@ -1867,7 +1907,7 @@ def __init__(self, func, *args, **kwargs): ''' self.func, self.args, self.kwargs = func, args, kwargs - def __call__(self, mtries=3, logger=logging.log): + def __call__(self, mtries: int = 3, logger: "LoggerFunction" = logging.log) -> Callable[..., Any]: ''' :param mtries: maximum number of attempts to execute the function :param logger: preferred logger @@ -1893,9 +1933,9 @@ class StoreAndDeprecateWarningAction(argparse.Action): ''' def __init__(self, - option_strings, - new_option_string, - dest, + option_strings: Sequence[str], + new_option_string: str, + dest: str, **kwargs): """ :param option_strings: all possible argument name strings @@ -1910,7 +1950,7 @@ def __init__(self, assert new_option_string in option_strings self.new_option_string = new_option_string - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, parser, namespace, values, option_string: Optional[str] = None): if option_string and option_string != self.new_option_string: # The logger gets typically initialized after the argument parser # to set the verbosity of the logger. Thus using simple print to console. @@ -1926,12 +1966,12 @@ class StoreTrueAndDeprecateWarningAction(argparse._StoreConstAction): ''' def __init__(self, - option_strings, - new_option_string, - dest, - default=False, - required=False, - help=None): + option_strings: Sequence[str], + new_option_string: str, + dest: str, + default: bool = False, + required: bool = False, + help: Optional[str] = None): """ :param option_strings: all possible argument name strings :param new_option_string: the new option string which replaces the old @@ -1948,7 +1988,7 @@ def __init__(self, assert new_option_string in option_strings self.new_option_string = new_option_string - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, parser, namespace, values, option_string: Optional[str] = None): super(StoreTrueAndDeprecateWarningAction, self).__call__(parser, namespace, values, option_string=option_string) if option_string and option_string != self.new_option_string: # The logger gets typically initialized after the argument parser @@ -1967,7 +2007,7 @@ class PriorityQueue: [1] https://en.wikipedia.org/wiki/Heap_(data_structure) """ class ContainerSlot: - def __init__(self, position, priority): + def __init__(self, position: int, priority: int): self.pos = position self.prio = priority @@ -2059,7 +2099,7 @@ def _priority_increased(self, item): return heap_changed -def check_policy_package_version(package): +def check_policy_package_version(package: str) -> None: import importlib from rucio.version import version_string @@ -2097,7 +2137,12 @@ class Availability: write = None delete = None - def __init__(self, read=None, write=None, delete=None): + def __init__( + self, + read: Optional[bool] = None, + write: Optional[bool] = None, + delete: Optional[bool] = None + ): self.read = read self.write = write self.delete = delete