From 99153e14e4c9c9848b252b36df9c3adb0b03271d Mon Sep 17 00:00:00 2001 From: rdimaio Date: Mon, 25 Mar 2024 18:31:17 +0100 Subject: [PATCH 1/4] Client: refactor constructor to improve static type checking; #6588 --- lib/rucio/client/baseclient.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/lib/rucio/client/baseclient.py b/lib/rucio/client/baseclient.py index 8165af20bb..2bfdb03dc8 100644 --- a/lib/rucio/client/baseclient.py +++ b/lib/rucio/client/baseclient.py @@ -107,8 +107,6 @@ def __init__(self, :param logger: Logger object to use. If None, use the default LOG created by the module """ - self.host = rucio_host - self.auth_host = auth_host self.logger = logger self.session = Session() self.user_agent = "%s/%s" % (user_agent, version.version_string()) # e.g. "rucio-clients/0.2.13" @@ -117,9 +115,17 @@ def __init__(self, if self.script_id == '': # Python interpreter used self.script_id = 'python' try: - if self.host is None: + if rucio_host is not None: + self.host = rucio_host + else: self.host = config_get('client', 'rucio_host') - if self.auth_host is None: + except (NoOptionError, NoSectionError) as error: + raise MissingClientParameter('Section client and Option \'%s\' cannot be found in config file' % error.args[0]) + + try: + if auth_host is not None: + self.auth_host = auth_host + else: self.auth_host = config_get('client', 'auth_host') except (NoOptionError, NoSectionError) as error: raise MissingClientParameter('Section client and Option \'%s\' cannot be found in config file' % error.args[0]) @@ -318,9 +324,12 @@ def _get_exception(self, headers: dict[str, str], status_code: Optional[int] = N :return: A rucio exception class and an error string. """ - try: - data = parse_response(data) - except ValueError: + if data is not None: + try: + data = parse_response(data) + except ValueError: + data = {} + else: data = {} exc_cls = 'RucioException' From 265b48783c2d2bedd020e028fdbaf5631fc84c37 Mon Sep 17 00:00:00 2001 From: rdimaio Date: Mon, 25 Mar 2024 18:31:54 +0100 Subject: [PATCH 2/4] Testing: Add type annotations to utils.py and related function calls; #6588 --- lib/rucio/client/downloadclient.py | 2 +- lib/rucio/common/types.py | 8 + lib/rucio/common/utils.py | 219 +++++++++++------- .../core/did_meta_plugins/filter_engine.py | 6 +- lib/rucio/core/oidc.py | 2 +- 5 files changed, 145 insertions(+), 92 deletions(-) diff --git a/lib/rucio/client/downloadclient.py b/lib/rucio/client/downloadclient.py index bb88eeb4e4..ff49dd5156 100644 --- a/lib/rucio/client/downloadclient.py +++ b/lib/rucio/client/downloadclient.py @@ -1231,7 +1231,7 @@ def _resolve_and_merge_input_items(self, input_items, sort=None): resolve_parents=True, nrandom=nrandom, metalink=True) - file_items = parse_replicas_from_string(metalink_str) + file_items = parse_replicas_from_string(metalink_str) # type: ignore for file in file_items: if impl: file['impl'] = impl diff --git a/lib/rucio/common/types.py b/lib/rucio/common/types.py index 441ab396b3..f533e14b6a 100644 --- a/lib/rucio/common/types.py +++ b/lib/rucio/common/types.py @@ -205,3 +205,11 @@ class TokenValidationDict(TypedDict): lifetime: datetime audience: Optional[str] authz_scope: Optional[str] + + +class IPDict(TypedDict): + ip: str + fqdn: str + site: str + latitude: Optional[float] + longitude: Optional[float] diff --git a/lib/rucio/common/utils.py b/lib/rucio/common/utils.py index 8f85955d88..bd25b42d48 100644 --- a/lib/rucio/common/utils.py +++ b/lib/rucio/common/utils.py @@ -37,12 +37,12 @@ import time 
import zlib from collections import OrderedDict -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence from enum import Enum from functools import partial, wraps from io import StringIO from itertools import zip_longest -from typing import TYPE_CHECKING, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse from uuid import uuid4 as uuid from xml.etree import ElementTree @@ -65,6 +65,10 @@ if TYPE_CHECKING: T = TypeVar('T') + from _typeshed import FileDescriptorOrPath + from sqlalchemy.orm import Session + + from rucio.common.types import IPDict, LoggerFunction # HTTP code dictionary. Not complete. Can be extended if needed. @@ -97,7 +101,7 @@ DATE_FORMAT = '%a, %d %b %Y %H:%M:%S UTC' -def invert_dict(d): +def invert_dict(d: dict[Any, Any]) -> dict[Any, Any]: """ Invert the dictionary. CAUTION: this function is not deterministic unless the input dictionary is one-to-one mapping. @@ -108,11 +112,11 @@ def invert_dict(d): return {value: key for key, value in d.items()} -def dids_as_dicts(did_list): +def dids_as_dicts(did_list: Iterable[Union[str, dict[str, str]]]) -> list[dict[str, str]]: """ Converts list of DIDs to list of dictionaries - :param did_list: list of DIDs as either "scope:name" or {"scope":"scope", "name","name"} - :returns: list of dictionaries {"scope":"scope", "name","name"} + :param did_list: list of DIDs as either "scope:name" or {"scope":"scope", "name":"name"} + :returns: list of dictionaries {"scope":"scope", "name":"name"} """ out = [] for did in did_list: @@ -128,7 +132,12 @@ def dids_as_dicts(did_list): return out -def build_url(url, path=None, params=None, doseq=False): +def build_url( + url: str, + path: Optional[str] = None, + params: Optional[Union[str, dict[Any, Any], list[tuple[Any, Any]]]] = None, + doseq: bool = False +) -> str: """ utitily function to build an url for requests to the rucio system. @@ -147,7 +156,13 @@ def build_url(url, path=None, params=None, doseq=False): return complete_url -def all_oidc_req_claims_present(scope, audience, required_scope, required_audience, separator=" "): +def all_oidc_req_claims_present( + scope: Optional[Union[str, list[str]]], + audience: Optional[Union[str, list[str]]], + required_scope: Optional[Union[str, list[str]]], + required_audience: Optional[Union[str, list[str]]], + sepatator: str = " " +) -> bool: """ Checks if both of the following statements are true: - all items in required_scope are present in scope string @@ -206,11 +221,11 @@ def all_oidc_req_claims_present(scope, audience, required_scope, required_audien return False -def generate_uuid(): +def generate_uuid() -> str: return str(uuid()).replace('-', '').lower() -def generate_uuid_bytes(): +def generate_uuid_bytes() -> bytes: return uuid().bytes @@ -221,7 +236,7 @@ def generate_uuid_bytes(): CHECKSUM_KEY = 'supported_checksums' -def is_checksum_valid(checksum_name): +def is_checksum_valid(checksum_name: str) -> bool: """ A simple function to check whether a checksum algorithm is supported. Relies on GLOBALLY_SUPPORTED_CHECKSUMS to allow for expandability. @@ -233,20 +248,19 @@ def is_checksum_valid(checksum_name): return checksum_name in GLOBALLY_SUPPORTED_CHECKSUMS -def set_preferred_checksum(checksum_name): +def set_preferred_checksum(checksum_name: str) -> None: """ - A simple function to check whether a checksum algorithm is supported. 
- Relies on GLOBALLY_SUPPORTED_CHECKSUMS to allow for expandability. + If the input checksum name is valid, + set it as PREFERRED_CHECKSUM. :param checksum_name: The name of the checksum to be verified. - :returns: True if checksum_name is in GLOBALLY_SUPPORTED_CHECKSUMS list, False otherwise. """ if is_checksum_valid(checksum_name): global PREFERRED_CHECKSUM PREFERRED_CHECKSUM = checksum_name -def adler32(file): +def adler32(file: "FileDescriptorOrPath") -> str: """ An Adler-32 checksum is obtained by calculating two 16-bit checksums A and B and concatenating their bits into a 32-bit integer. A is the sum of all bytes in the @@ -293,7 +307,7 @@ def adler32(file): CHECKSUM_ALGO_DICT['adler32'] = adler32 -def md5(file): +def md5(file: "FileDescriptorOrPath") -> str: """ Runs the MD5 algorithm (RFC-1321) on the binary content of the file named file and returns the hexadecimal digest @@ -313,7 +327,7 @@ def md5(file): CHECKSUM_ALGO_DICT['md5'] = md5 -def sha256(file): +def sha256(file: "FileDescriptorOrPath") -> str: """ Runs the SHA256 algorithm on the binary content of the file named file and returns the hexadecimal digest @@ -330,7 +344,7 @@ def sha256(file): CHECKSUM_ALGO_DICT['sha256'] = sha256 -def crc32(file): +def crc32(file: "FileDescriptorOrPath") -> str: """ Runs the CRC32 algorithm on the binary content of the file named file and returns the hexadecimal digest @@ -346,7 +360,7 @@ def crc32(file): CHECKSUM_ALGO_DICT['crc32'] = crc32 -def _next_pow2(num): +def _next_pow2(num: int) -> int: if not num: return 0 return math.ceil(math.log2(num)) @@ -392,7 +406,7 @@ def bittorrent_v2_piece_length(file_size: int) -> int: return 2 ** _bittorrent_v2_piece_length_pow2(file_size) -def bittorrent_v2_merkle_sha256(file) -> tuple[bytes, bytes, int]: +def bittorrent_v2_merkle_sha256(file: "FileDescriptorOrPath") -> tuple[bytes, bytes, int]: """ Compute the .torrent v2 hash tree for the given file. (http://www.bittorrent.org/beps/bep_0052.html) @@ -483,7 +497,7 @@ def _merkle_root(leafs: list[bytes], nb_levels: int, padding: bytes) -> bytes: return pieces_root, pieces_layers, piece_length -def merkle_sha256(file) -> str: +def merkle_sha256(file: "FileDescriptorOrPath") -> str: """ The root of the sha256 merkle hash tree with leaf size of 16 KiB. """ @@ -494,7 +508,7 @@ def merkle_sha256(file) -> str: CHECKSUM_ALGO_DICT['merkle_sha256'] = merkle_sha256 -def bencode(obj) -> bytes: +def bencode(obj: Union[int, bytes, str, list, dict[bytes, Any]]) -> bytes: """ Copied from the reference implementation of v2 bittorrent: http://bittorrent.org/beps/bep_0052_torrent_creator.py @@ -515,7 +529,7 @@ def bencode(obj) -> bytes: return b"d" + b"".join(map(bencode, itertools.chain(*items))) + b"e" else: raise ValueError("dict keys should be bytes " + str(obj.keys())) - raise ValueError("Allowed types: int, bytes, list, dict; not %s", type(obj)) + raise ValueError("Allowed types: int, bytes, str, list, dict; not %s", type(obj)) def construct_torrent( @@ -558,7 +572,7 @@ def construct_torrent( return torrent_id, torrent -def str_to_date(string): +def str_to_date(string: str) -> Optional[datetime.datetime]: """ Converts a RFC-1123 string to the corresponding datetime value. :param string: the RFC-1123 string to convert to datetime value. 
@@ -566,7 +580,7 @@ def str_to_date(string): return datetime.datetime.strptime(string, DATE_FORMAT) if string else None -def val_to_space_sep_str(vallist): +def val_to_space_sep_str(vallist: list[str]) -> str: """ Converts a list of values into a string of space separated values :param vallist: the list of values to to convert into string @@ -581,7 +595,7 @@ def val_to_space_sep_str(vallist): return '' -def date_to_str(date): +def date_to_str(date: datetime.datetime) -> Optional[str]: """ Converts a datetime value to the corresponding RFC-1123 string. :param date: the datetime value to convert. @@ -611,19 +625,19 @@ def default(self, obj): # pylint: disable=E0202 return json.JSONEncoder.default(self, obj) -def render_json(**data): +def render_json(**data: Any) -> str: """ JSON render function """ return json.dumps(data, cls=APIEncoder) -def render_json_list(list_): +def render_json_list(list_) -> str: """ JSON render function for list """ return json.dumps(list_, cls=APIEncoder) -def datetime_parser(dct): +def datetime_parser(dct: dict[Any, Any]) -> dict[Any, Any]: """ datetime parser """ for k, v in list(dct.items()): @@ -635,17 +649,17 @@ def datetime_parser(dct): return dct -def parse_response(data): +def parse_response(data: Union[str, bytes, bytearray]) -> Any: """ JSON render function """ - if hasattr(data, 'decode'): + if isinstance(data, (bytes, bytearray)): data = data.decode('utf-8') return json.loads(data, object_hook=datetime_parser) -def execute(cmd) -> tuple[int, str, str]: +def execute(cmd: str) -> tuple[int, str, str]: """ Executes a command in a subprocess. Returns a tuple of (exitcode, out, err), where out is the string output @@ -667,17 +681,17 @@ def execute(cmd) -> tuple[int, str, str]: return exitcode, out.decode(encoding='utf-8'), err.decode(encoding='utf-8') -def rse_supported_protocol_operations(): +def rse_supported_protocol_operations() -> list[str]: """ Returns a list with operations supported by all RSE protocols.""" return ['read', 'write', 'delete', 'third_party_copy_read', 'third_party_copy_write'] -def rse_supported_protocol_domains(): +def rse_supported_protocol_domains() -> list[str]: """ Returns a list with all supported RSE protocol domains.""" return ['lan', 'wan'] -def grouper(iterable, n, fillvalue=None): +def grouper(iterable: Iterable[Any], n: int, fillvalue: Optional[object] = None) -> zip_longest: """ Collect data into fixed-length chunks or blocks """ # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx args = [iter(iterable)] * n @@ -700,7 +714,7 @@ def chunks(iterable, n): yield chunk -def dict_chunks(dict_, n): +def dict_chunks(dict_: dict[Any, Any], n: int) -> Iterator[dict[Any, Any]]: """ Iterate over the dictionary in groups of the requested size """ @@ -709,13 +723,13 @@ def dict_chunks(dict_, n): yield {k: dict_[k] for k in itertools.islice(it, n)} -def my_key_generator(namespace, fn, **kw): +def my_key_generator(namespace: str, fn: Callable, **kw) -> Callable[..., str]: """ - Customyzed key generator for dogpile + Customized key generator for dogpile """ fname = fn.__name__ - def generate_key(*arg, **kw): + def generate_key(*arg, **kw) -> str: return namespace + "_" + fname + "_".join(str(s) for s in filter(None, arg)) return generate_key @@ -767,7 +781,7 @@ def get_algorithm(cls: type[SurlAlgorithmsT], naming_convention: str) -> Callabl return super()._get_one_algorithm(cls._algorithm_type, naming_convention) @classmethod - def register(cls: type[SurlAlgorithmsT], name: str, fn_construct_surl: Callable[[str, str, str], str]) -> None: + def 
register(cls: Type[SurlAlgorithmsT], name: str, fn_construct_surl: Callable[[str, str, str], Optional[str]]) -> None: """ Register a new SURL algorithm """ @@ -850,7 +864,7 @@ def construct_surl_DQ2(dsn: str, scope: str, filename: str) -> str: return '/%s/%s/%s/%s/%s' % (project, dataset_type, tag, stripped_dsn, filename) @staticmethod - def construct_surl_T0(dsn: str, scope: str, filename: str) -> str: + def construct_surl_T0(dsn: str, scope: str, filename: str) -> Optional[str]: """ Defines relative SURL for new replicas. This method contains Tier0 convention. To be used for non-deterministic sites. @@ -890,7 +904,7 @@ def construct_surl_BelleII(dsn: str, scope: str, filename: str) -> str: SurlAlgorithms._module_init_() -def construct_surl(dsn: str, scope: str, filename: str, naming_convention: str = None) -> str: +def construct_surl(dsn: str, scope: str, filename: str, naming_convention: Optional[str] = None) -> str: """ Applies non-deterministic source url convention to the given replica. use the naming_convention to call the actual function which will do the job. @@ -904,7 +918,7 @@ def construct_surl(dsn: str, scope: str, filename: str, naming_convention: str = return surl_algorithms.construct_surl(dsn, scope, filename, naming_convention) -def clean_surls(surls): +def clean_surls(surls: Iterable[str]) -> list[str]: res = [] for surl in surls: if surl.startswith('srm'): @@ -1083,7 +1097,11 @@ def extract_scope_belleii(did: str, scopes: Optional[Sequence[str]]) -> Sequence ScopeExtractionAlgorithms._module_init_() -def extract_scope(did, scopes=None, default_extract=_DEFAULT_EXTRACT): +def extract_scope( + did: str, + scopes: Optional[Sequence[str]] = None, + default_extract: str = _DEFAULT_EXTRACT +) -> Sequence[str]: scope_extraction_algorithms = ScopeExtractionAlgorithms() extract_scope_convention = config_get('common', 'extract_scope', False, None) or config_get('policy', 'extract_scope', False, None) if extract_scope_convention is None or not ScopeExtractionAlgorithms.supports(extract_scope_convention): @@ -1091,7 +1109,7 @@ def extract_scope(did, scopes=None, default_extract=_DEFAULT_EXTRACT): return scope_extraction_algorithms.extract_scope(did, scopes, extract_scope_convention) -def pid_exists(pid): +def pid_exists(pid: int) -> bool: """ Check whether pid exists in the current process table. UNIX only. @@ -1121,7 +1139,7 @@ def pid_exists(pid): return True -def sizefmt(num, human=True): +def sizefmt(num: Union[int, float], human: bool = True) -> str: """ Print human readable file sizes """ @@ -1141,7 +1159,7 @@ def sizefmt(num, human=True): return 'Inf' -def get_tmp_dir(): +def get_tmp_dir() -> str: """ Get a path where to store temporary files. @@ -1169,7 +1187,7 @@ def get_tmp_dir(): return base_dir -def is_archive(name): +def is_archive(name: str) -> bool: ''' Check if a file name is an archive file or not. @@ -1215,7 +1233,7 @@ def resolve_ip(hostname: str) -> str: return hostname -def detect_client_location(): +def detect_client_location() -> "IPDict": """ Normally client IP will be set on the server side (request.remote_addr) Here setting ip on the one seen by the host itself. 
There is no connection @@ -1292,7 +1310,7 @@ def ssh_sign(private_key: str, message: str) -> str: return base64_encoded -def make_valid_did(lfn_dict): +def make_valid_did(lfn_dict: dict[str, Any]) -> dict[str, Any]: """ When managing information about a LFN (such as in `rucio upload` or the RSE manager's upload), we add the `filename` attribute to record @@ -1312,7 +1330,7 @@ def make_valid_did(lfn_dict): return lfn_copy -def send_trace(trace, trace_endpoint, user_agent, retries=5): +def send_trace(trace: dict[str, Any], trace_endpoint: str, user_agent: str, retries: int = 5) -> int: """ Send the given trace to the trace endpoint @@ -1333,7 +1351,7 @@ def send_trace(trace, trace_endpoint, user_agent, retries=5): return 1 -def add_url_query(url, query): +def add_url_query(url: str, query: dict[str, str]) -> str: """ Add a new dictionary to URL parameters @@ -1349,7 +1367,7 @@ def add_url_query(url, query): return urlunparse(url_parts) -def get_bytes_value_from_string(input_string): +def get_bytes_value_from_string(input_string: str) -> Union[bool, int]: """ Get bytes from a string that represents a storage value and unit @@ -1379,7 +1397,7 @@ def get_bytes_value_from_string(input_string): return False -def parse_did_filter_from_string(input_string): +def parse_did_filter_from_string(input_string: str) -> tuple[dict[str, Any], str]: """ Parse DID filter options in format 'length<3,type=all' from string. @@ -1416,13 +1434,13 @@ def parse_did_filter_from_string(input_string): value = datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%fZ') if key == 'type': - if value.upper() in ['ALL', 'COLLECTION', 'CONTAINER', 'DATASET', 'FILE']: - type_ = value.lower() + if value.upper() in ['ALL', 'COLLECTION', 'CONTAINER', 'DATASET', 'FILE']: # type: ignore + type_ = value.lower() # type: ignore else: raise InvalidType('{0} is not a valid type. Valid types are {1}'.format(value, ['ALL', 'COLLECTION', 'CONTAINER', 'DATASET', 'FILE'])) elif key in ('length.gt', 'length.lt', 'length.gte', 'length.lte', 'length'): try: - value = int(value) + value = int(value) # type: ignore filters[key] = value except ValueError: raise ValueError('Length has to be an integer value.') @@ -1439,7 +1457,12 @@ def parse_did_filter_from_string(input_string): return filters, type_ -def parse_did_filter_from_string_fe(input_string, name='*', type='collection', omit_name=False): +def parse_did_filter_from_string_fe( + input_string: str, + name: str = '*', + type: str = 'collection', + omit_name: bool = False +) -> tuple[list[dict[str, Any]], str]: """ Parse DID filter string for the filter engine (fe). @@ -1547,7 +1570,7 @@ def parse_did_filter_from_string_fe(input_string, name='*', type='collection', o return filters, type -def parse_replicas_from_file(path): +def parse_replicas_from_file(path: "FileDescriptorOrPath") -> Any: """ Parses the output of list_replicas from a json or metalink file into a dictionary. Metalink parsing is tried first and if it fails @@ -1568,7 +1591,7 @@ def parse_replicas_from_file(path): raise MetalinkJsonParsingError(path, xml_err, json_err) -def parse_replicas_from_string(string): +def parse_replicas_from_string(string: str) -> Any: """ Parses the output of list_replicas from a json or metalink string into a dictionary. 
Metalink parsing is tried first and if it fails @@ -1588,7 +1611,7 @@ def parse_replicas_from_string(string): raise MetalinkJsonParsingError(string, xml_err, json_err) -def parse_replicas_metalink(root): +def parse_replicas_metalink(root: ElementTree.Element) -> list[dict[str, Any]]: """ Transforms the metalink tree into a list of dictionaries where each dictionary describes a file with its replicas. @@ -1645,7 +1668,11 @@ def parse_replicas_metalink(root): return files -def get_thread_with_periodic_running_function(interval, action, graceful_stop): +def get_thread_with_periodic_running_function( + interval: Union[int, float], + action: Callable[..., Any], + graceful_stop: threading.Event +) -> threading.Thread: """ Get a thread where a function runs periodically. @@ -1662,7 +1689,7 @@ def start(): return t -def run_cmd_process(cmd, timeout=3600): +def run_cmd_process(cmd: str, timeout: int = 3600) -> tuple[int, str]: """ shell command parser with timeout @@ -1703,7 +1730,10 @@ def run_cmd_process(cmd, timeout=3600): return returncode, stdout -def gateway_update_return_dict(dictionary, session=None): +def gateway_update_return_dict( + dictionary: dict[str, Any], + session: Optional["Session"] = None +) -> dict[str, Any]: """ Ensure that rse is in a dictionary returned from core @@ -1741,7 +1771,12 @@ def gateway_update_return_dict(dictionary, session=None): return dictionary -def setup_logger(module_name=None, logger_name=None, logger_level=None, verbose=False): +def setup_logger( + module_name: Optional[str] = None, + logger_name: Optional[str] = None, + logger_level: Optional[int] = None, + verbose: bool = False +) -> logging.Logger: ''' Factory method to set logger with handlers. :param module_name: __name__ of the module that is calling this method @@ -1783,11 +1818,11 @@ def _force_cfg_log_level(cfg_option): logger.setLevel(logger_level) # preferred logger handling - def add_handler(logger): + def add_handler(logger: logging.Logger) -> None: hdlr = logging.StreamHandler() - def emit_decorator(fnc): - def func(*args): + def emit_decorator(fnc: Callable[..., Any]) -> Callable[..., Any]: + def func(*args) -> Callable[..., Any]: if 'RUCIO_LOGGING_FORMAT' not in os.environ: levelno = args[0].levelno format_str = '%(asctime)s\t%(levelname)s\t%(message)s\033[0m' @@ -1820,7 +1855,12 @@ def func(*args): return logger -def daemon_sleep(start_time, sleep_time, graceful_stop, logger=logging.log): +def daemon_sleep( + start_time: float, + sleep_time: float, + graceful_stop: threading.Event, + logger: "LoggerFunction" = logging.log +) -> None: """Sleeps a daemon the time provided by sleep_time""" end_time = time.time() time_diff = end_time - start_time @@ -1829,7 +1869,7 @@ def daemon_sleep(start_time, sleep_time, graceful_stop, logger=logging.log): graceful_stop.wait(sleep_time - time_diff) -def is_client(): +def is_client() -> bool: """" Checks if the function is called from a client or from a server/daemon @@ -1858,7 +1898,7 @@ def is_client(): class retry: """Retry callable object with configuragle number of attempts""" - def __init__(self, func, *args, **kwargs): + def __init__(self, func: Callable[..., Any], *args, **kwargs): ''' :param func: a method that should be executed with retries :param args: parameters of the func @@ -1866,7 +1906,7 @@ def __init__(self, func, *args, **kwargs): ''' self.func, self.args, self.kwargs = func, args, kwargs - def __call__(self, mtries=3, logger=logging.log): + def __call__(self, mtries: int = 3, logger: "LoggerFunction" = logging.log) -> Callable[..., 
Any]: ''' :param mtries: maximum number of attempts to execute the function :param logger: preferred logger @@ -1892,9 +1932,9 @@ class StoreAndDeprecateWarningAction(argparse.Action): ''' def __init__(self, - option_strings, - new_option_string, - dest, + option_strings: Sequence[str], + new_option_string: str, + dest: str, **kwargs): """ :param option_strings: all possible argument name strings @@ -1909,7 +1949,7 @@ def __init__(self, assert new_option_string in option_strings self.new_option_string = new_option_string - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, parser, namespace, values, option_string: Optional[str] = None): if option_string and option_string != self.new_option_string: # The logger gets typically initialized after the argument parser # to set the verbosity of the logger. Thus using simple print to console. @@ -1925,12 +1965,12 @@ class StoreTrueAndDeprecateWarningAction(argparse._StoreConstAction): ''' def __init__(self, - option_strings, - new_option_string, - dest, - default=False, - required=False, - help=None): + option_strings: Sequence[str], + new_option_string: str, + dest: str, + default: bool = False, + required: bool = False, + help: Optional[str] = None): """ :param option_strings: all possible argument name strings :param new_option_string: the new option string which replaces the old @@ -1947,7 +1987,7 @@ def __init__(self, assert new_option_string in option_strings self.new_option_string = new_option_string - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, parser, namespace, values, option_string: Optional[str] = None): super(StoreTrueAndDeprecateWarningAction, self).__call__(parser, namespace, values, option_string=option_string) if option_string and option_string != self.new_option_string: # The logger gets typically initialized after the argument parser @@ -1966,7 +2006,7 @@ class PriorityQueue: [1] https://en.wikipedia.org/wiki/Heap_(data_structure) """ class ContainerSlot: - def __init__(self, position, priority): + def __init__(self, position: int, priority: int): self.pos = position self.prio = priority @@ -2058,7 +2098,7 @@ def _priority_increased(self, item): return heap_changed -def check_policy_package_version(package): +def check_policy_package_version(package: str) -> None: import importlib from rucio.version import version_string @@ -2096,7 +2136,12 @@ class Availability: write = None delete = None - def __init__(self, read=None, write=None, delete=None): + def __init__( + self, + read: Optional[bool] = None, + write: Optional[bool] = None, + delete: Optional[bool] = None + ): self.read = read self.write = write self.delete = delete diff --git a/lib/rucio/core/did_meta_plugins/filter_engine.py b/lib/rucio/core/did_meta_plugins/filter_engine.py index 03161ace0f..3983623110 100644 --- a/lib/rucio/core/did_meta_plugins/filter_engine.py +++ b/lib/rucio/core/did_meta_plugins/filter_engine.py @@ -287,7 +287,7 @@ def create_mongo_query( # Add additional filters, applied as AND clauses to each OR group. for or_group in self._filters: for filter in additional_filters: - or_group.append(list(filter)) + or_group.append(list(filter)) # type: ignore or_expressions = [] for or_group in self._filters: @@ -349,7 +349,7 @@ def create_postgres_query( # Add additional filters, applied as AND clauses to each OR group. 
for or_group in self._filters: for _filter in additional_filters: - or_group.append(list(_filter)) + or_group.append(list(_filter)) # type: ignore or_expressions = [] for or_group in self._filters: @@ -439,7 +439,7 @@ def create_sqla_query( # Add additional filters, applied as AND clauses to each OR group. for or_group in self._filters: for _filter in additional_filters: - or_group.append(list(_filter)) + or_group.append(list(_filter)) # type: ignore or_expressions = [] for or_group in self._filters: diff --git a/lib/rucio/core/oidc.py b/lib/rucio/core/oidc.py index df56fa3790..ca4c98c40c 100644 --- a/lib/rucio/core/oidc.py +++ b/lib/rucio/core/oidc.py @@ -1274,7 +1274,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): for key in keys: value = '' if key in claims: - value = val_to_space_sep_str(claims[key]) + value = val_to_space_sep_str(claims[key]) # type: ignore resdict[key] = value return resdict except Exception as error: From b0701070faae3b6c441548b57ef70bd1c52eef5e Mon Sep 17 00:00:00 2001 From: rdimaio Date: Mon, 25 Mar 2024 16:33:36 +0100 Subject: [PATCH 3/4] setup_logger: check module_name is not None prior to regex match --- lib/rucio/common/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/rucio/common/utils.py b/lib/rucio/common/utils.py index bd25b42d48..54bf52240c 100644 --- a/lib/rucio/common/utils.py +++ b/lib/rucio/common/utils.py @@ -1785,10 +1785,10 @@ def setup_logger( :param verbose: verbose option set in bin/rucio ''' # helper method for cfg check - def _force_cfg_log_level(cfg_option): + def _force_cfg_log_level(cfg_option: str) -> bool: cfg_forced_modules = config_get('logging', cfg_option, raise_exception=False, default=None, clean_cached=True, check_config_table=False) - if cfg_forced_modules: + if cfg_forced_modules and module_name is not None: if re.match(str(cfg_forced_modules), module_name): return True return False From 2a1e9aaeef9102246c0bd38e745fc1b1a006783a Mon Sep 17 00:00:00 2001 From: rdimaio Date: Tue, 26 Mar 2024 14:46:35 +0100 Subject: [PATCH 4/4] Testing: Refactor FilterEngine and add type annotations; #6588 --- lib/rucio/common/utils.py | 6 +- .../core/did_meta_plugins/filter_engine.py | 94 +++++++++++-------- 2 files changed, 58 insertions(+), 42 deletions(-) diff --git a/lib/rucio/common/utils.py b/lib/rucio/common/utils.py index 54bf52240c..a1e0ac263d 100644 --- a/lib/rucio/common/utils.py +++ b/lib/rucio/common/utils.py @@ -42,7 +42,7 @@ from functools import partial, wraps from io import StringIO from itertools import zip_longest -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse from uuid import uuid4 as uuid from xml.etree import ElementTree @@ -161,7 +161,7 @@ def all_oidc_req_claims_present( audience: Optional[Union[str, list[str]]], required_scope: Optional[Union[str, list[str]]], required_audience: Optional[Union[str, list[str]]], - sepatator: str = " " + separator: str = " " ) -> bool: """ Checks if both of the following statements are true: @@ -781,7 +781,7 @@ def get_algorithm(cls: type[SurlAlgorithmsT], naming_convention: str) -> Callabl return super()._get_one_algorithm(cls._algorithm_type, naming_convention) @classmethod - def register(cls: Type[SurlAlgorithmsT], name: str, fn_construct_surl: Callable[[str, str, str], Optional[str]]) -> None: + def register(cls: type[SurlAlgorithmsT], name: str, fn_construct_surl: Callable[[str, 
str, str], Optional[str]]) -> None: """ Register a new SURL algorithm """ diff --git a/lib/rucio/core/did_meta_plugins/filter_engine.py b/lib/rucio/core/did_meta_plugins/filter_engine.py index 3983623110..97ecfebc42 100644 --- a/lib/rucio/core/did_meta_plugins/filter_engine.py +++ b/lib/rucio/core/did_meta_plugins/filter_engine.py @@ -15,10 +15,9 @@ import ast import fnmatch import operator -from collections.abc import Callable, Iterable from datetime import date, datetime, timedelta from importlib import import_module -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union import sqlalchemy from sqlalchemy import and_, cast, or_ @@ -31,8 +30,14 @@ from rucio.db.sqla.session import read_session if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from sqlalchemy.orm import Session + from rucio.db.sqla.models import ModelBase + + KeyType = TypeVar("KeyType", str, InstrumentedAttribute) + FilterTuple = tuple[KeyType, Callable[[object, object], Any], Union[bool, datetime, float, str]] # lookup table converting keyword suffixes to pythonic operators. OPERATORS_CONVERSION_LUT = { @@ -78,25 +83,30 @@ class FilterEngine: """ An engine to provide advanced filtering functionality to DID listing requests. """ - def __init__(self, filters, model_class=None, strict_coerce=True): + def __init__( + self, + filters: Union[str, dict[str, Any], list[dict[str, Any]]], + model_class: Optional[type["ModelBase"]] = None, + strict_coerce: bool = True + ): if isinstance(filters, str): - self._filters, _ = parse_did_filter_from_string_fe(filters, omit_name=True) + filters, _ = parse_did_filter_from_string_fe(filters, omit_name=True) elif isinstance(filters, dict): - self._filters = [filters] + filters = [filters] elif isinstance(filters, list): - self._filters = filters + filters = filters else: raise exception.DIDFilterSyntaxError("Input filters are of an unrecognised type.") - self._make_input_backwards_compatible() - self.mandatory_model_attributes = self._translate_filters(model_class=model_class, strict_coerce=strict_coerce) + filters = self._make_input_backwards_compatible(filters=filters) + self._filters, self.mandatory_model_attributes = self._translate_filters(filters=filters, model_class=model_class, strict_coerce=strict_coerce) self._sanity_check_translated_filters() @property - def filters(self): + def filters(self) -> list[list["FilterTuple"]]: return self._filters - def _coerce_filter_word_to_model_attribute(self, word, model_class, strict=True): + def _coerce_filter_word_to_model_attribute(self, word: Any, model_class: Optional[type["ModelBase"]], strict: bool = True) -> Any: """ Attempts to coerce a filter word to an attribute of a . @@ -114,7 +124,7 @@ def _coerce_filter_word_to_model_attribute(self, word, model_class, strict=True) raise exception.KeyNotFound("'{}' keyword could not be coerced to model class attribute. Attribute not found.".format(word)) return word - def _make_input_backwards_compatible(self): + def _make_input_backwards_compatible(self, filters: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Backwards compatibility for previous versions of filtering. 
@@ -122,13 +132,14 @@ def _make_input_backwards_compatible(self): - converts "created_after" key to "created_at.gte" - converts "created_before" key to "created_at.lte" """ - for or_group in self._filters: + for or_group in filters: if 'created_after' in or_group: or_group['created_at.gte'] = or_group.pop('created_after') elif 'created_before' in or_group: or_group['created_at.lte'] = or_group.pop('created_before') + return filters - def _sanity_check_translated_filters(self): + def _sanity_check_translated_filters(self) -> None: """ Perform a few sanity checks on translated filters. @@ -154,7 +165,7 @@ def _sanity_check_translated_filters(self): raise ValueError("Name operator must be an equality operator.") if key == 'length': # (3) try: - int(value) + int(value) # type: ignore except ValueError: raise ValueError('Length has to be an integer value.') @@ -171,7 +182,12 @@ def _sanity_check_translated_filters(self): if len(set(or_group_test_duplicates)) != len(or_group_test_duplicates): # (6) raise exception.DuplicateCriteriaInDIDFilter() - def _translate_filters(self, model_class, strict_coerce=True): + def _translate_filters( + self, + filters: "Iterable[dict[str, Any]]", + model_class: Optional[type["ModelBase"]], + strict_coerce: bool = True + ) -> tuple[list[list["FilterTuple"]], list[InstrumentedAttribute[Any]]]: """ Reformats filters from: @@ -200,9 +216,10 @@ def _translate_filters(self, model_class, strict_coerce=True): Typecasting of values is also attempted. + :param filters: The filters to translate. :param model_class: The SQL model class. :param strict_coerce: Enforce that keywords must be coercible to a model attribute. - :returns: The set of mandatory model attributes to be used in the filter query. + :returns: The list of translated filters, and the set of mandatory model attributes to be used in the filter query. :raises: MissingModuleException, DIDFilterSyntaxError """ if model_class: @@ -213,7 +230,7 @@ def _translate_filters(self, model_class, strict_coerce=True): mandatory_model_attributes = set() filters_translated = [] - for or_group in self._filters: + for or_group in filters: and_group_parsed = [] for key, value in or_group.items(): # KEY @@ -246,10 +263,9 @@ def _translate_filters(self, model_class, strict_coerce=True): and_group_parsed.append( (key_no_suffix, OPERATORS_CONVERSION_LUT.get(oper), value)) filters_translated.append(and_group_parsed) - self._filters = filters_translated - return list(mandatory_model_attributes) + return filters_translated, list(mandatory_model_attributes) - def _try_typecast_string(self, value): + def _try_typecast_string(self, value: str) -> Union[bool, datetime, float, str]: """ Check if string can be typecasted to bool, datetime or float. @@ -260,11 +276,11 @@ def _try_typecast_string(self, value): value = value.replace('false', 'False').replace('FALSE', 'False') for format in VALID_DATE_FORMATS: # try parsing multiple date formats. 
try: - value = datetime.strptime(value, format) + typecasted_value = datetime.strptime(value, format) except ValueError: continue else: - return value + return typecasted_value try: operators = ('+', '-', '*', '/') if not any(operator in value for operator in operators): # fix for lax ast literal_eval in earlier python versions @@ -275,7 +291,7 @@ def _try_typecast_string(self, value): def create_mongo_query( self, - additional_filters: Optional[Iterable[tuple[str, Callable, str]]] = None + additional_filters: Optional["Iterable[FilterTuple]"] = None ) -> dict[str, Any]: """ Returns a single mongo query describing the filters expression. @@ -289,9 +305,9 @@ def create_mongo_query( for filter in additional_filters: or_group.append(list(filter)) # type: ignore - or_expressions = [] + or_expressions: list[dict[str, Any]] = [] for or_group in self._filters: - and_expressions = [] + and_expressions: list[dict[str, dict[str, Any]]] = [] for and_group in or_group: key, oper, value = and_group if isinstance(value, str) and any([char in value for char in ['*', '%']]): # wildcards @@ -334,8 +350,8 @@ def create_mongo_query( def create_postgres_query( self, - additional_filters: Optional[Iterable[tuple[str, Callable, str]]] = None, - fixed_table_columns: Iterable[str] = ('scope', 'name', 'vo'), + additional_filters: Optional["Iterable[FilterTuple]"] = None, + fixed_table_columns: Union[tuple[str, ...], dict[str, str]] = ('scope', 'name', 'vo'), jsonb_column: str = 'data' ) -> str: """ @@ -351,9 +367,9 @@ def create_postgres_query( for _filter in additional_filters: or_group.append(list(_filter)) # type: ignore - or_expressions = [] + or_expressions: list[str] = [] for or_group in self._filters: - and_expressions = [] + and_expressions: list[str] = [] for and_group in or_group: key, oper, value = and_group if key in fixed_table_columns: # is this key filtering on a column or in the jsonb? @@ -415,8 +431,8 @@ def create_sqla_query( self, *, session: "Session", - additional_model_attributes: Optional[list[InstrumentedAttribute]] = None, - additional_filters: Optional[Iterable[tuple[str, Callable, str]]] = None, + additional_model_attributes: Optional[list[InstrumentedAttribute[Any]]] = None, + additional_filters: Optional["Iterable[FilterTuple]"] = None, json_column: Optional[InstrumentedAttribute] = None ) -> Query: """ @@ -441,12 +457,12 @@ def create_sqla_query( for _filter in additional_filters: or_group.append(list(_filter)) # type: ignore - or_expressions = [] + or_expressions: list = [] for or_group in self._filters: and_expressions = [] for and_group in or_group: key, oper, value = and_group - if isinstance(key, sqlalchemy.orm.attributes.InstrumentedAttribute): # -> this key filters on a table column. + if isinstance(key, InstrumentedAttribute): # -> this key filters on a table column. if isinstance(value, str) and any([char in value for char in ['*', '%']]): # wildcards if value in ('*', '%', '*', '%'): # match wildcard exactly == no filtering on key continue @@ -511,7 +527,7 @@ def create_sqla_query( or_expressions.append(and_(*and_expressions)) return session.query(*all_model_attributes).filter(or_(*or_expressions)) - def evaluate(self): + def evaluate(self) -> bool: """ Evaluates an expression and returns a boolean result. @@ -526,7 +542,7 @@ def evaluate(self): or_group_evaluations.append(all(and_group_evaluations)) return any(or_group_evaluations) - def print_filters(self): + def print_filters(self) -> str: """ A (more) human readable format of . 
""" @@ -536,16 +552,16 @@ def print_filters(self): for or_group in self._filters: for and_group in or_group: key, oper, value = and_group - if isinstance(key, sqlalchemy.orm.attributes.InstrumentedAttribute): + if isinstance(key, InstrumentedAttribute): key = and_group[0].key if operators_conversion_LUT_inv[oper] == "": oper = "eq" else: oper = operators_conversion_LUT_inv[oper] - if isinstance(value, sqlalchemy.orm.attributes.InstrumentedAttribute): - value = and_group[2].key + if isinstance(value, InstrumentedAttribute): + value = and_group[2].key # type: ignore elif isinstance(value, DIDType): - value = and_group[2].name + value = and_group[2].name # type: ignore filters = "{}{} {} {}".format(filters, key, oper, value) if and_group != or_group[-1]: filters += ' AND '