Skip to content

Commit

Permalink
Testing: Add type annotations for dumper functions; #6588
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio authored and bari12 committed Apr 19, 2024
1 parent 48ba197 commit 1409f6f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
57 changes: 36 additions & 21 deletions lib/rucio/common/dumper/__init__.py
Expand Up @@ -21,6 +21,7 @@
import re
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import gfal2
import magic
Expand All @@ -29,29 +30,39 @@
from rucio.common import config
from rucio.core.rse import get_rse_id, get_rse_protocols

if TYPE_CHECKING:
from collections.abc import Iterator
from multiprocessing.connection import Connection
from types import ModuleType
from typing import IO, TextIO

from _typeshed import FileDescriptorOrPath, GenericPath, StrOrBytesPath

from rucio.common.types import RSEProtocolDict


class HTTPDownloadFailed(Exception):
    '''
    Raised when an HTTP download does not succeed.

    :param msg: Human readable description of the failure.
    :param code: Status code of the failed response, if known. When given it
                 is kept in `self.code` and appended to the message as
                 ``(Status <code>)``.
    '''

    def __init__(self, msg: str = '', code: Optional[str] = None):
        self.code = code
        if code is not None:
            # Make the status visible in str(exception), not only in `.code`.
            msg = '{0} (Status {1})'.format(msg, code)
        super().__init__(msg)


class LogPipeHandler(logging.Handler):
    '''
    Logging handler that forwards every formatted record through a pipe.

    :param pipe: Connection object used as the sink; only its `send()` and
                 `close()` methods are called by this handler.
    '''

    def __init__(self, pipe: "Connection"):
        super().__init__()
        self.pipe = pipe

    def emit(self, record: logging.LogRecord) -> None:
        '''
        Format `record` with the handler's formatter and send the resulting
        string through the pipe.
        '''
        self.pipe.send(self.format(record))

    def close(self) -> None:
        '''
        Run the standard handler teardown, then close the underlying pipe.
        '''
        super().close()
        self.pipe.close()


def error(text, exit_code=1):
def error(text: str, exit_code: int = 1) -> None:
'''
Log and print `text` error. This function ends the execution of the program with exit code
`exit_code` (defaults to 1).
Expand All @@ -62,7 +73,7 @@ def error(text, exit_code=1):
exit(1)


def mkdir(dir_):
def mkdir(dir_: "StrOrBytesPath") -> None:
'''
This function creates the `dir_` directory if it doesn't exist. If `dir_`
already exists this function does nothing.
Expand All @@ -73,7 +84,7 @@ def mkdir(dir_):
assert error.errno == 17


def cacert_config(config, rucio_home):
def cacert_config(config: "ModuleType", rucio_home: str) -> Optional[Union["FileDescriptorOrPath", Literal[False]]]:
logger = logging.getLogger('dumper.__init__')
try:
cacert = config.config_get('client', 'ca_cert').replace('$RUCIO_HOME', rucio_home)
Expand All @@ -87,13 +98,13 @@ def cacert_config(config, rucio_home):
return cacert


def rucio_home() -> str:
    '''
    Return the Rucio home directory.

    :returns: Value of the `RUCIO_HOME` environment variable, or
              `/opt/rucio` when the variable is not set.
    '''
    return os.environ.get('RUCIO_HOME', '/opt/rucio')


def get_requests_session() -> requests.Session:
    '''
    Build a `requests.Session` for downloads.

    The session's `verify` attribute is set to whatever `cacert_config`
    returns for the loaded configuration and `rucio_home()`, and `stream`
    is switched on.

    :returns: The configured session.
    '''
    requests_session = requests.Session()
    # cacert_config's declared return type is wider than what `verify` is
    # typed to accept, hence the ignore.
    requests_session.verify = cacert_config(config, rucio_home())  # type: ignore
    requests_session.stream = True
    return requests_session

Expand All @@ -115,7 +126,7 @@ def get_requests_session():
# pylint: enable=no-member


def isplaintext(filename):
def isplaintext(filename: "GenericPath") -> bool:
'''
Returns True if `filename` has mimetype == 'text/plain'.
'''
Expand All @@ -124,7 +135,7 @@ def isplaintext(filename):
return mimetype(filename).split(';')[0] == 'text/plain'


def smart_open(filename):
def smart_open(filename: "GenericPath") -> Optional[Union["TextIO", gzip.GzipFile]]:
'''
Returns an open file object if `filename` is plain text, else assumes
it is a bzip2 compressed file and returns a file-like object to
Expand All @@ -145,7 +156,11 @@ def smart_open(filename):


@contextlib.contextmanager
def temp_file(directory, final_name=None, binary=False):
def temp_file(
directory: str,
final_name: Optional[str] = None,
binary: bool = False
) -> "Iterator[tuple[IO[Any], StrOrBytesPath]]":
'''
Creates a temporary file to store partial results; when the
file is complete it is renamed to `final_name`.
Expand Down Expand Up @@ -192,7 +207,7 @@ def temp_file(directory, final_name=None, binary=False):
MILLISECONDS_RE = re.compile(r'\.(\d{3})Z$')


def to_datetime(str_or_datetime):
def to_datetime(str_or_datetime: Union[datetime.datetime, str]) -> Optional[datetime.datetime]:
"""
Convert string to datetime. The format is somewhat flexible.
Timezone information is ignored.
Expand Down Expand Up @@ -230,17 +245,17 @@ def to_datetime(str_or_datetime):
return date


def ddmendpoint_preferred_protocol(ddmendpoint: str) -> "RSEProtocolDict":
    '''
    Return the first protocol of `ddmendpoint` that has WAN read enabled.

    :param ddmendpoint: Name resolved to an RSE id through `get_rse_id`.
    :returns: First entry of the RSE's `protocols` list whose
              `['domains']['wan']['read']` value equals 1.
    :raises StopIteration: If no protocol in the list matches.
    '''
    protocols = get_rse_protocols(get_rse_id(ddmendpoint))['protocols']
    return next(p for p in protocols if p['domains']['wan']['read'] == 1)


def ddmendpoint_url(ddmendpoint: str) -> str:
    '''
    Build the base URL of `ddmendpoint` from its preferred protocol.

    :param ddmendpoint: Name handed to `ddmendpoint_preferred_protocol`.
    :returns: `<scheme>://<hostname>:<port>` followed by the protocol's
              prefix, with a trailing `rucio/` stripped from that prefix.
    '''
    preferred_protocol = ddmendpoint_preferred_protocol(ddmendpoint)
    # Only a `rucio/` at the very end of the prefix is removed.
    prefix = re.sub(r'rucio/$', '', preferred_protocol['prefix'])
    return '{scheme}://{hostname}:{port}'.format(**preferred_protocol) + prefix


def http_download_to_file(url, file_, session=None):
def http_download_to_file(url: str, file_: "IO", session: Optional[requests.Session] = None) -> None:
'''
Download the file in `url` storing it in the `file_` file-like
object.
Expand All @@ -259,7 +274,7 @@ def _do_download(url, file_, session, try_decode=False):
url,
response.status_code,
)
raise HTTPDownloadFailed('Error downloading ' + url, response.status_code)
raise HTTPDownloadFailed('Error downloading ' + url, str(response.status_code))

if try_decode:
if response.encoding is None:
Expand All @@ -278,15 +293,15 @@ def _do_download(url, file_, session, try_decode=False):
_do_download(url, file_, session, True)


def http_download(url: str, filename: "FileDescriptorOrPath") -> None:
    '''
    Download the file in `url` storing it in the path given by `filename`.

    `filename` is opened in text write mode (`'w'`), so an existing file is
    truncated; the actual transfer is delegated to `http_download_to_file`.
    '''
    with open(filename, 'w') as f:
        http_download_to_file(url, f)


def gfal_download_to_file(url, file_):
def gfal_download_to_file(url: str, file_: "IO") -> None:
'''
Download the file in `url` storing it in the `file_` file-like
object.
Expand All @@ -311,7 +326,7 @@ def gfal_download_to_file(url, file_):
chunk = infile.read(CHUNK_SIZE)


def gfal_download(url, filename):
def gfal_download(url: str, filename: "FileDescriptorOrPath") -> None:
'''
Download the file in `url` storing it in the path given by `filename`.
'''
Expand Down
2 changes: 1 addition & 1 deletion lib/rucio/common/dumper/data_models.py
Expand Up @@ -186,7 +186,7 @@ def download(cls, rse, date='latest', cache_dir=DUMPS_CACHE_DIR):
url,
response.status_code,
)
raise HTTPDownloadFailed('Downloading {0} dump'.format(cls.__name__), code=response.status_code)
raise HTTPDownloadFailed('Downloading {0} dump'.format(cls.__name__), code=str(response.status_code))

with temp_file(cache_dir, final_name=filename) as (tfile, _):
http_download_to_file(url, tfile, session=requests_session)
Expand Down

0 comments on commit 1409f6f

Please sign in to comment.