From 6eb8cf8157a87b8a350ceadfe9f2579f749ce47f Mon Sep 17 00:00:00 2001 From: rdimaio Date: Wed, 27 Mar 2024 14:22:57 +0100 Subject: [PATCH] Testing: Add type annotations for dumper functions; #6588 --- lib/rucio/common/dumper/__init__.py | 57 ++++++++++++++++---------- lib/rucio/common/dumper/data_models.py | 2 +- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/lib/rucio/common/dumper/__init__.py b/lib/rucio/common/dumper/__init__.py index eda86cf34d..54ba837fad 100644 --- a/lib/rucio/common/dumper/__init__.py +++ b/lib/rucio/common/dumper/__init__.py @@ -21,6 +21,7 @@ import re import sys import tempfile +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import gfal2 import magic @@ -29,9 +30,19 @@ from rucio.common import config from rucio.core.rse import get_rse_id, get_rse_protocols +if TYPE_CHECKING: + from collections.abc import Iterator + from multiprocessing.connection import Connection + from types import ModuleType + from typing import IO, TextIO + + from _typeshed import FileDescriptorOrPath, GenericPath, StrOrBytesPath + + from rucio.common.types import RSEProtocolDict + class HTTPDownloadFailed(Exception): - def __init__(self, msg='', code=None): + def __init__(self, msg: str = '', code: Optional[str] = None): self.code = code if code is not None: msg = '{0} (Status {1})'.format(msg, code) @@ -39,19 +50,19 @@ def __init__(self, msg='', code=None): class LogPipeHandler(logging.Handler): - def __init__(self, pipe): + def __init__(self, pipe: "Connection"): super(LogPipeHandler, self).__init__() self.pipe = pipe - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: self.pipe.send(self.format(record)) - def close(self): + def close(self) -> None: super(LogPipeHandler, self).close() self.pipe.close() -def error(text, exit_code=1): +def error(text: str, exit_code: int = 1) -> None: ''' Log and print `text` error. This function ends the execution of the program with exit code `exit_code` (defaults to 1). @@ -62,7 +73,7 @@ def error(text, exit_code=1): exit(1) -def mkdir(dir_): +def mkdir(dir_: "StrOrBytesPath") -> None: ''' This functions creates the `dir_` directory if it doesn't exist. If `dir_` already exists this function does nothing. @@ -73,7 +84,7 @@ def mkdir(dir_): assert error.errno == 17 -def cacert_config(config, rucio_home): +def cacert_config(config: "ModuleType", rucio_home: str) -> Optional[Union["FileDescriptorOrPath", Literal[False]]]: logger = logging.getLogger('dumper.__init__') try: cacert = config.config_get('client', 'ca_cert').replace('$RUCIO_HOME', rucio_home) @@ -87,13 +98,13 @@ def cacert_config(config, rucio_home): return cacert -def rucio_home(): +def rucio_home() -> str: return os.environ.get('RUCIO_HOME', '/opt/rucio') -def get_requests_session(): +def get_requests_session() -> requests.Session: requests_session = requests.Session() - requests_session.verify = cacert_config(config, rucio_home()) + requests_session.verify = cacert_config(config, rucio_home()) # type: ignore requests_session.stream = True return requests_session @@ -115,7 +126,7 @@ def get_requests_session(): # pylint: enable=no-member -def isplaintext(filename): +def isplaintext(filename: "GenericPath") -> bool: ''' Returns True if `filename` has mimetype == 'text/plain'. ''' @@ -124,7 +135,7 @@ def isplaintext(filename): return mimetype(filename).split(';')[0] == 'text/plain' -def smart_open(filename): +def smart_open(filename: "GenericPath") -> Optional[Union["TextIO", gzip.GzipFile]]: ''' Returns an open file object if `filename` is plain text, else assumes it is a bzip2 compressed file and returns a file-like object to @@ -145,7 +156,11 @@ def smart_open(filename): @contextlib.contextmanager -def temp_file(directory, final_name=None, binary=False): +def temp_file( + directory: str, + final_name: Optional[str] = None, + binary: bool = False +) -> "Iterator[tuple[IO[Any], StrOrBytesPath]]": ''' Allows to create a temporal file to store partial results, when the file is complete it is renamed to `final_name`. @@ -192,7 +207,7 @@ def temp_file(directory, final_name=None, binary=False): MILLISECONDS_RE = re.compile(r'\.(\d{3})Z$') -def to_datetime(str_or_datetime): +def to_datetime(str_or_datetime: Union[datetime.datetime, str]) -> Optional[datetime.datetime]: """ Convert string to datetime. The format is somewhat flexible. Timezone information is ignored. @@ -230,17 +245,17 @@ def to_datetime(str_or_datetime): return date -def ddmendpoint_preferred_protocol(ddmendpoint): +def ddmendpoint_preferred_protocol(ddmendpoint: str) -> "RSEProtocolDict": return next(p for p in get_rse_protocols(get_rse_id(ddmendpoint))['protocols'] if p['domains']['wan']['read'] == 1) -def ddmendpoint_url(ddmendpoint): +def ddmendpoint_url(ddmendpoint: str) -> str: preferred_protocol = ddmendpoint_preferred_protocol(ddmendpoint) prefix = re.sub(r'rucio/$', '', preferred_protocol['prefix']) return '{scheme}://{hostname}:{port}'.format(**preferred_protocol) + prefix -def http_download_to_file(url, file_, session=None): +def http_download_to_file(url: str, file_: "IO", session: Optional[requests.Session] = None) -> None: ''' Download the file in `url` storing it in the `file_` file-like object. @@ -259,7 +274,7 @@ def _do_download(url, file_, session, try_decode=False): url, response.status_code, ) - raise HTTPDownloadFailed('Error downloading ' + url, response.status_code) + raise HTTPDownloadFailed('Error downloading ' + url, str(response.status_code)) if try_decode: if response.encoding is None: @@ -278,7 +293,7 @@ def _do_download(url, file_, session, try_decode=False): _do_download(url, file_, session, True) -def http_download(url, filename): +def http_download(url: str, filename: "FileDescriptorOrPath") -> None: ''' Download the file in `url` storing it in the path given by `filename`. ''' @@ -286,7 +301,7 @@ def http_download(url, filename): http_download_to_file(url, f) -def gfal_download_to_file(url, file_): +def gfal_download_to_file(url: str, file_: "IO") -> None: ''' Download the file in `url` storing it in the `file_` file-like object. @@ -311,7 +326,7 @@ def gfal_download_to_file(url, file_): chunk = infile.read(CHUNK_SIZE) -def gfal_download(url, filename): +def gfal_download(url: str, filename: "FileDescriptorOrPath") -> None: ''' Download the file in `url` storing it in the path given by `filename`. ''' diff --git a/lib/rucio/common/dumper/data_models.py b/lib/rucio/common/dumper/data_models.py index dfb4f20022..20629e1d3b 100644 --- a/lib/rucio/common/dumper/data_models.py +++ b/lib/rucio/common/dumper/data_models.py @@ -186,7 +186,7 @@ def download(cls, rse, date='latest', cache_dir=DUMPS_CACHE_DIR): url, response.status_code, ) - raise HTTPDownloadFailed('Downloading {0} dump'.format(cls.__name__), code=response.status_code) + raise HTTPDownloadFailed('Downloading {0} dump'.format(cls.__name__), code=str(response.status_code)) with temp_file(cache_dir, final_name=filename) as (tfile, _): http_download_to_file(url, tfile, session=requests_session)