Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing: Add type annotations for dumper functions #6614

Merged
Merged 1 commit on Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 36 additions & 21 deletions lib/rucio/common/dumper/__init__.py
Expand Up @@ -21,6 +21,7 @@
import re
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import gfal2
import magic
Expand All @@ -29,29 +30,39 @@
from rucio.common import config
from rucio.core.rse import get_rse_id, get_rse_protocols

if TYPE_CHECKING:
from collections.abc import Iterator
from multiprocessing.connection import Connection
from types import ModuleType
from typing import IO, TextIO

from _typeshed import FileDescriptorOrPath, GenericPath, StrOrBytesPath

from rucio.common.types import RSEProtocolDict


class HTTPDownloadFailed(Exception):
def __init__(self, msg='', code=None):
def __init__(self, msg: str = '', code: Optional[str] = None):
self.code = code
if code is not None:
msg = '{0} (Status {1})'.format(msg, code)
super(HTTPDownloadFailed, self).__init__(msg)


class LogPipeHandler(logging.Handler):
def __init__(self, pipe):
def __init__(self, pipe: "Connection"):
super(LogPipeHandler, self).__init__()
self.pipe = pipe

def emit(self, record):
def emit(self, record: logging.LogRecord) -> None:
self.pipe.send(self.format(record))

def close(self):
def close(self) -> None:
super(LogPipeHandler, self).close()
self.pipe.close()


def error(text, exit_code=1):
def error(text: str, exit_code: int = 1) -> None:
'''
Log and print `text` error. This function ends the execution of the program with exit code
`exit_code` (defaults to 1).
Expand All @@ -62,7 +73,7 @@ def error(text, exit_code=1):
exit(1)


def mkdir(dir_):
def mkdir(dir_: "StrOrBytesPath") -> None:
'''
This function creates the `dir_` directory if it doesn't exist. If `dir_`
already exists this function does nothing.
Expand All @@ -73,7 +84,7 @@ def mkdir(dir_):
assert error.errno == 17


def cacert_config(config, rucio_home):
def cacert_config(config: "ModuleType", rucio_home: str) -> Optional[Union["FileDescriptorOrPath", Literal[False]]]:
logger = logging.getLogger('dumper.__init__')
try:
cacert = config.config_get('client', 'ca_cert').replace('$RUCIO_HOME', rucio_home)
Expand All @@ -87,13 +98,13 @@ def cacert_config(config, rucio_home):
return cacert


def rucio_home():
def rucio_home() -> str:
return os.environ.get('RUCIO_HOME', '/opt/rucio')


def get_requests_session():
def get_requests_session() -> requests.Session:
requests_session = requests.Session()
requests_session.verify = cacert_config(config, rucio_home())
requests_session.verify = cacert_config(config, rucio_home()) # type: ignore
requests_session.stream = True
return requests_session

Expand All @@ -115,7 +126,7 @@ def get_requests_session():
# pylint: enable=no-member


def isplaintext(filename):
def isplaintext(filename: "GenericPath") -> bool:
'''
Returns True if `filename` has mimetype == 'text/plain'.
'''
Expand All @@ -124,7 +135,7 @@ def isplaintext(filename):
return mimetype(filename).split(';')[0] == 'text/plain'


def smart_open(filename):
def smart_open(filename: "GenericPath") -> Optional[Union["TextIO", gzip.GzipFile]]:
'''
Returns an open file object if `filename` is plain text, else assumes
it is a bzip2 compressed file and returns a file-like object to
Expand All @@ -145,7 +156,11 @@ def smart_open(filename):


@contextlib.contextmanager
def temp_file(directory, final_name=None, binary=False):
def temp_file(
directory: str,
final_name: Optional[str] = None,
binary: bool = False
) -> "Iterator[tuple[IO[Any], StrOrBytesPath]]":
'''
Creates a temporary file to store partial results; when the
file is complete it is renamed to `final_name`.
Expand Down Expand Up @@ -192,7 +207,7 @@ def temp_file(directory, final_name=None, binary=False):
MILLISECONDS_RE = re.compile(r'\.(\d{3})Z$')


def to_datetime(str_or_datetime):
def to_datetime(str_or_datetime: Union[datetime.datetime, str]) -> Optional[datetime.datetime]:
"""
Convert string to datetime. The format is somewhat flexible.
Timezone information is ignored.
Expand Down Expand Up @@ -230,17 +245,17 @@ def to_datetime(str_or_datetime):
return date


def ddmendpoint_preferred_protocol(ddmendpoint):
def ddmendpoint_preferred_protocol(ddmendpoint: str) -> "RSEProtocolDict":
return next(p for p in get_rse_protocols(get_rse_id(ddmendpoint))['protocols'] if p['domains']['wan']['read'] == 1)


def ddmendpoint_url(ddmendpoint):
def ddmendpoint_url(ddmendpoint: str) -> str:
preferred_protocol = ddmendpoint_preferred_protocol(ddmendpoint)
prefix = re.sub(r'rucio/$', '', preferred_protocol['prefix'])
return '{scheme}://{hostname}:{port}'.format(**preferred_protocol) + prefix


def http_download_to_file(url, file_, session=None):
def http_download_to_file(url: str, file_: "IO", session: Optional[requests.Session] = None) -> None:
'''
Download the file in `url` storing it in the `file_` file-like
object.
Expand All @@ -259,7 +274,7 @@ def _do_download(url, file_, session, try_decode=False):
url,
response.status_code,
)
raise HTTPDownloadFailed('Error downloading ' + url, response.status_code)
raise HTTPDownloadFailed('Error downloading ' + url, str(response.status_code))

if try_decode:
if response.encoding is None:
Expand All @@ -278,15 +293,15 @@ def _do_download(url, file_, session, try_decode=False):
_do_download(url, file_, session, True)


def http_download(url, filename):
def http_download(url: str, filename: "FileDescriptorOrPath") -> None:
'''
Download the file in `url` storing it in the path given by `filename`.
'''
with open(filename, 'w') as f:
http_download_to_file(url, f)


def gfal_download_to_file(url, file_):
def gfal_download_to_file(url: str, file_: "IO") -> None:
'''
Download the file in `url` storing it in the `file_` file-like
object.
Expand All @@ -311,7 +326,7 @@ def gfal_download_to_file(url, file_):
chunk = infile.read(CHUNK_SIZE)


def gfal_download(url, filename):
def gfal_download(url: str, filename: "FileDescriptorOrPath") -> None:
'''
Download the file in `url` storing it in the path given by `filename`.
'''
Expand Down
2 changes: 1 addition & 1 deletion lib/rucio/common/dumper/data_models.py
Expand Up @@ -186,7 +186,7 @@ def download(cls, rse, date='latest', cache_dir=DUMPS_CACHE_DIR):
url,
response.status_code,
)
raise HTTPDownloadFailed('Downloading {0} dump'.format(cls.__name__), code=response.status_code)
raise HTTPDownloadFailed('Downloading {0} dump'.format(cls.__name__), code=str(response.status_code))

with temp_file(cache_dir, final_name=filename) as (tfile, _):
http_download_to_file(url, tfile, session=requests_session)
Expand Down