Skip to content

Commit

Permalink
Testing: Add type annotations to topology.py; rucio#6588
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Mar 21, 2024
1 parent bed3f05 commit 394b464
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 17 deletions.
4 changes: 3 additions & 1 deletion lib/rucio/common/constants.py
Expand Up @@ -14,6 +14,7 @@

import enum
from collections import namedtuple
from typing import Literal, get_args

from rucio.common.config import config_get_bool

Expand Down Expand Up @@ -47,7 +48,8 @@
SCHEME_MAP['srm'].append('davs')
SCHEME_MAP['davs'].append('srm')

SUPPORTED_PROTOCOLS = ['gsiftp', 'srm', 'root', 'davs', 'http', 'https', 'file', 'storm', 'srm+https', 'scp', 'rsync', 'rclone', 'magnet']
SUPPORTED_PROTOCOLS_LITERAL = Literal['gsiftp', 'srm', 'root', 'davs', 'http', 'https', 'file', 'storm', 'srm+https', 'scp', 'rsync', 'rclone', 'magnet']
SUPPORTED_PROTOCOLS = list(get_args(SUPPORTED_PROTOCOLS_LITERAL))

FTS_STATE = namedtuple('FTS_STATE', ['SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY', 'NOT_USED',
'CANCELED'])('SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY',
Expand Down
11 changes: 11 additions & 0 deletions lib/rucio/common/types.py
Expand Up @@ -14,6 +14,8 @@

from typing import Any, Callable, Literal, Optional, TypedDict, Union

from rucio.common.constants import SUPPORTED_PROTOCOLS_LITERAL


class InternalType:
'''
Expand Down Expand Up @@ -178,3 +180,12 @@ class RuleDict(TypedDict):
class DIDDict(TypedDict):
name: str
scope: InternalScope


class HopDict(TypedDict):
source_rse_id: str
source_scheme: SUPPORTED_PROTOCOLS_LITERAL
source_scheme_priority: int
dest_rse_id: str
dest_scheme: SUPPORTED_PROTOCOLS_LITERAL
dest_scheme_priority: int
34 changes: 18 additions & 16 deletions lib/rucio/core/topology.py
Expand Up @@ -32,7 +32,6 @@
from rucio.db.sqla.session import read_session, transactional_session
from rucio.rse import rsemanager as rsemgr

LoggerFunction = Callable[..., Any]
_Number = Union[int, Decimal]
TN = TypeVar("TN", bound="Node")
TE = TypeVar("TE", bound="Edge")
Expand All @@ -41,6 +40,9 @@
from typing import Protocol

from sqlalchemy.orm import Session
from typing_extensions import Self

from rucio.common.types import HopDict, LoggerFunction

class _StateProvider(Protocol):
@property
Expand Down Expand Up @@ -81,11 +83,11 @@ def __init__(self, src_node: TN, dst_node: TN):

self.add_to_nodes()

def add_to_nodes(self):
def add_to_nodes(self) -> None:
self.src_node.out_edges[self.dst_node] = self
self.dst_node.in_edges[self.src_node] = self

def remove_from_nodes(self):
def remove_from_nodes(self) -> None:
self.src_node.out_edges.pop(self.dst_node, None)
self.dst_node.in_edges.pop(self.src_node, None)

Expand All @@ -105,12 +107,12 @@ def dst_node(self) -> TN:
raise ReferenceError("weak reference returned None")
return node

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
return self._src_node == other._src_node and self._dst_node == other._dst_node

def __str__(self):
def __str__(self) -> str:
return f'{self._src_node}-->{self._dst_node}'


Expand All @@ -127,7 +129,7 @@ def __init__(
):
super().__init__(rse_ids=rse_ids, rse_data_cls=node_cls)
self._edge_cls = edge_cls
self._edges = {}
self._edges: dict[tuple[TN, TN], TE] = {}
self._edges_loaded = False
self._multihop_nodes = set()
self._hop_penalty = DEFAULT_HOP_PENALTY
Expand All @@ -148,7 +150,7 @@ def ensure_loaded(
include_deleted: bool = False,
*,
session: "Session",
):
) -> None:

if not rse_ids:
with self._lock:
Expand Down Expand Up @@ -177,7 +179,7 @@ def get_or_create(self, rse_id: str) -> "TN":
return rse_data

@property
def edges(self):
def edges(self) -> dict[tuple[TN, TN], TE]:
with self._lock:
return copy.copy(self._edges)

Expand All @@ -193,7 +195,7 @@ def get_or_create_edge(self, src_node: TN, dst_node: TN) -> "TE":
self._edges[src_node, dst_node] = edge = self._edge_cls(src_node, dst_node)
return edge

def delete_edge(self, src_node: TN, dst_node: TN):
def delete_edge(self, src_node: TN, dst_node: TN) -> None:
with self._lock:
edge = self._edges[src_node, dst_node]
edge.remove_from_nodes()
Expand All @@ -203,11 +205,11 @@ def multihop_enabled(self) -> bool:
return True if self._multihop_nodes else False

@read_session
def configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: LoggerFunction = logging.log):
def configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: "LoggerFunction" = logging.log) -> "Self":
with self._lock:
return self._configure_multihop(multihop_rse_ids=multihop_rse_ids, session=session, logger=logger)

def _configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: LoggerFunction = logging.log):
def _configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: "LoggerFunction" = logging.log) -> "Self":

if multihop_rse_ids is None:
multihop_rse_expression = config_get('transfers', 'multihop_rse_expression', default='available_for_multihop=true', expiration_time=600, session=session)
Expand Down Expand Up @@ -236,7 +238,7 @@ def _configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, se
return self

@read_session
def ensure_edges_loaded(self, *, session: "Session"):
def ensure_edges_loaded(self, *, session: "Session") -> None:
"""
Ensure that all edges are loaded for the (sub-)set of nodes known by this topology object
"""
Expand All @@ -246,7 +248,7 @@ def ensure_edges_loaded(self, *, session: "Session"):
with self._lock:
return self._ensure_edges_loaded(session=session)

def _ensure_edges_loaded(self, *, session: "Session"):
def _ensure_edges_loaded(self, *, session: "Session") -> None:
stmt = select(
models.Distance
).where(
Expand Down Expand Up @@ -281,7 +283,7 @@ def _ensure_edges_loaded(self, *, session: "Session"):
@read_session
def search_shortest_paths(
self,
src_nodes: list[TN],
src_nodes: Iterable[TN],
dst_node: TN,
operation_src: str,
operation_dest: str,
Expand Down Expand Up @@ -434,7 +436,7 @@ def __init__(self, ttl, new_obj_fnc):
self._new_obj_fnc = new_obj_fnc
self._ttl = ttl

def get(self, logger=logging.log):
def get(self, logger: "LoggerFunction" = logging.log) -> object:
with self._lock:
if not self._object \
or not self._creation_time \
Expand All @@ -452,7 +454,7 @@ def get_hops(
multihop_rse_ids: Optional[set[str]] = None,
limit_dest_schemes: Optional[list[str]] = None,
*, session: "Session",
):
) -> list["HopDict"]:
"""
Get a list of hops needed to transfer date from source_rse_id to dest_rse_id.
Ideally, the list will only include one item (dest_rse_id) since no hops are needed.
Expand Down

0 comments on commit 394b464

Please sign in to comment.