diff --git a/lib/rucio/common/constants.py b/lib/rucio/common/constants.py index 9f5ebadbe2..0c7037410a 100644 --- a/lib/rucio/common/constants.py +++ b/lib/rucio/common/constants.py @@ -14,6 +14,7 @@ import enum from collections import namedtuple +from typing import Literal, get_args from rucio.common.config import config_get_bool @@ -47,7 +48,8 @@ SCHEME_MAP['srm'].append('davs') SCHEME_MAP['davs'].append('srm') -SUPPORTED_PROTOCOLS = ['gsiftp', 'srm', 'root', 'davs', 'http', 'https', 'file', 'storm', 'srm+https', 'scp', 'rsync', 'rclone', 'magnet'] +SUPPORTED_PROTOCOLS_LITERAL = Literal['gsiftp', 'srm', 'root', 'davs', 'http', 'https', 'file', 'storm', 'srm+https', 'scp', 'rsync', 'rclone', 'magnet'] +SUPPORTED_PROTOCOLS = list(get_args(SUPPORTED_PROTOCOLS_LITERAL)) FTS_STATE = namedtuple('FTS_STATE', ['SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY', 'NOT_USED', 'CANCELED'])('SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY', diff --git a/lib/rucio/common/types.py b/lib/rucio/common/types.py index d085839695..c648e8782e 100644 --- a/lib/rucio/common/types.py +++ b/lib/rucio/common/types.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypedDict, Union + +if TYPE_CHECKING: + from rucio.common.constants import SUPPORTED_PROTOCOLS_LITERAL class InternalType: @@ -178,3 +181,12 @@ class RuleDict(TypedDict): class DIDDict(TypedDict): name: str scope: InternalScope + + +class HopDict(TypedDict): + source_rse_id: str + source_scheme: "SUPPORTED_PROTOCOLS_LITERAL" + source_scheme_priority: int + dest_rse_id: str + dest_scheme: "SUPPORTED_PROTOCOLS_LITERAL" + dest_scheme_priority: int diff --git a/lib/rucio/core/topology.py b/lib/rucio/core/topology.py index 56b141f24e..d523cc3c7b 100644 --- a/lib/rucio/core/topology.py +++ b/lib/rucio/core/topology.py @@ -32,7 +32,6 @@ from rucio.db.sqla.session import read_session, transactional_session from rucio.rse import rsemanager as rsemgr -LoggerFunction = Callable[..., Any] _Number = Union[int, Decimal] TN = TypeVar("TN", bound="Node") TE = TypeVar("TE", bound="Edge") @@ -41,6 +40,9 @@ from typing import Protocol from sqlalchemy.orm import Session + from typing_extensions import Self + + from rucio.common.types import HopDict, LoggerFunction class _StateProvider(Protocol): @property @@ -81,11 +83,11 @@ def __init__(self, src_node: TN, dst_node: TN): self.add_to_nodes() - def add_to_nodes(self): + def add_to_nodes(self) -> None: self.src_node.out_edges[self.dst_node] = self self.dst_node.in_edges[self.src_node] = self - def remove_from_nodes(self): + def remove_from_nodes(self) -> None: self.src_node.out_edges.pop(self.dst_node, None) self.dst_node.in_edges.pop(self.src_node, None) @@ -105,12 +107,12 @@ def dst_node(self) -> TN: raise ReferenceError("weak reference returned None") return node - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False return self._src_node == other._src_node and self._dst_node == other._dst_node - def __str__(self): + def __str__(self) -> str: return f'{self._src_node}-->{self._dst_node}' @@ -127,7 +129,7 @@ def __init__( ): super().__init__(rse_ids=rse_ids, rse_data_cls=node_cls) self._edge_cls = edge_cls - self._edges = {} + self._edges: dict[tuple[TN, TN], TE] = {} self._edges_loaded = False self._multihop_nodes = set() self._hop_penalty = DEFAULT_HOP_PENALTY @@ -148,7 +150,7 @@ def ensure_loaded( include_deleted: bool = False, *, session: "Session", - ): + ) -> None: if not rse_ids: with self._lock: @@ -177,7 +179,7 @@ def get_or_create(self, rse_id: str) -> "TN": return rse_data @property - def edges(self): + def edges(self) -> dict[tuple[TN, TN], TE]: with self._lock: return copy.copy(self._edges) @@ -193,7 +195,7 @@ def get_or_create_edge(self, src_node: TN, dst_node: TN) -> "TE": self._edges[src_node, dst_node] = edge = self._edge_cls(src_node, dst_node) return edge - def delete_edge(self, src_node: TN, dst_node: TN): + def delete_edge(self, src_node: TN, dst_node: TN) -> None: with self._lock: edge = self._edges[src_node, dst_node] edge.remove_from_nodes() @@ -203,11 +205,11 @@ def multihop_enabled(self) -> bool: return True if self._multihop_nodes else False @read_session - def configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: LoggerFunction = logging.log): + def configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: "LoggerFunction" = logging.log) -> "Self": with self._lock: return self._configure_multihop(multihop_rse_ids=multihop_rse_ids, session=session, logger=logger) - def _configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: LoggerFunction = logging.log): + def _configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: "LoggerFunction" = logging.log) -> "Self": if multihop_rse_ids is None: multihop_rse_expression = config_get('transfers', 'multihop_rse_expression', default='available_for_multihop=true', expiration_time=600, session=session) @@ -236,7 +238,7 @@ def _configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, se return self @read_session - def ensure_edges_loaded(self, *, session: "Session"): + def ensure_edges_loaded(self, *, session: "Session") -> None: """ Ensure that all edges are loaded for the (sub-)set of nodes known by this topology object """ @@ -246,7 +248,7 @@ def ensure_edges_loaded(self, *, session: "Session"): with self._lock: return self._ensure_edges_loaded(session=session) - def _ensure_edges_loaded(self, *, session: "Session"): + def _ensure_edges_loaded(self, *, session: "Session") -> None: stmt = select( models.Distance ).where( @@ -281,7 +283,7 @@ def _ensure_edges_loaded(self, *, session: "Session"): @read_session def search_shortest_paths( self, - src_nodes: list[TN], + src_nodes: Iterable[TN], dst_node: TN, operation_src: str, operation_dest: str, @@ -434,7 +436,7 @@ def __init__(self, ttl, new_obj_fnc): self._new_obj_fnc = new_obj_fnc self._ttl = ttl - def get(self, logger=logging.log): + def get(self, logger: "LoggerFunction" = logging.log) -> object: with self._lock: if not self._object \ or not self._creation_time \ @@ -452,7 +454,7 @@ def get_hops( multihop_rse_ids: Optional[set[str]] = None, limit_dest_schemes: Optional[list[str]] = None, *, session: "Session", -): +) -> list["HopDict"]: """ Get a list of hops needed to transfer date from source_rse_id to dest_rse_id. Ideally, the list will only include one item (dest_rse_id) since no hops are needed.