Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing: Add type annotations to topology.py #6589

Merged
merged 1 commit into from Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion lib/rucio/common/constants.py
Expand Up @@ -14,6 +14,7 @@

import enum
from collections import namedtuple
from typing import Literal, get_args

from rucio.common.config import config_get_bool

Expand Down Expand Up @@ -47,7 +48,8 @@
SCHEME_MAP['srm'].append('davs')
SCHEME_MAP['davs'].append('srm')

SUPPORTED_PROTOCOLS = ['gsiftp', 'srm', 'root', 'davs', 'http', 'https', 'file', 'storm', 'srm+https', 'scp', 'rsync', 'rclone', 'magnet']
SUPPORTED_PROTOCOLS_LITERAL = Literal['gsiftp', 'srm', 'root', 'davs', 'http', 'https', 'file', 'storm', 'srm+https', 'scp', 'rsync', 'rclone', 'magnet']
SUPPORTED_PROTOCOLS = list(get_args(SUPPORTED_PROTOCOLS_LITERAL))

FTS_STATE = namedtuple('FTS_STATE', ['SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY', 'NOT_USED',
'CANCELED'])('SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY',
Expand Down
14 changes: 13 additions & 1 deletion lib/rucio/common/types.py
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Literal, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypedDict, Union

if TYPE_CHECKING:
from rucio.common.constants import SUPPORTED_PROTOCOLS_LITERAL


class InternalType:
Expand Down Expand Up @@ -178,3 +181,12 @@ class RuleDict(TypedDict):
class DIDDict(TypedDict):
name: str
scope: InternalScope


class HopDict(TypedDict):
source_rse_id: str
source_scheme: "SUPPORTED_PROTOCOLS_LITERAL"
source_scheme_priority: int
dest_rse_id: str
dest_scheme: "SUPPORTED_PROTOCOLS_LITERAL"
dest_scheme_priority: int
34 changes: 18 additions & 16 deletions lib/rucio/core/topology.py
Expand Up @@ -32,7 +32,6 @@
from rucio.db.sqla.session import read_session, transactional_session
from rucio.rse import rsemanager as rsemgr

LoggerFunction = Callable[..., Any]
_Number = Union[int, Decimal]
TN = TypeVar("TN", bound="Node")
TE = TypeVar("TE", bound="Edge")
Expand All @@ -41,6 +40,9 @@
from typing import Protocol

from sqlalchemy.orm import Session
from typing_extensions import Self

from rucio.common.types import HopDict, LoggerFunction

class _StateProvider(Protocol):
@property
Expand Down Expand Up @@ -81,11 +83,11 @@ def __init__(self, src_node: TN, dst_node: TN):

self.add_to_nodes()

def add_to_nodes(self):
def add_to_nodes(self) -> None:
self.src_node.out_edges[self.dst_node] = self
self.dst_node.in_edges[self.src_node] = self

def remove_from_nodes(self):
def remove_from_nodes(self) -> None:
self.src_node.out_edges.pop(self.dst_node, None)
self.dst_node.in_edges.pop(self.src_node, None)

Expand All @@ -105,12 +107,12 @@ def dst_node(self) -> TN:
raise ReferenceError("weak reference returned None")
return node

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
return self._src_node == other._src_node and self._dst_node == other._dst_node

def __str__(self):
def __str__(self) -> str:
return f'{self._src_node}-->{self._dst_node}'


Expand All @@ -127,7 +129,7 @@ def __init__(
):
super().__init__(rse_ids=rse_ids, rse_data_cls=node_cls)
self._edge_cls = edge_cls
self._edges = {}
self._edges: dict[tuple[TN, TN], TE] = {}
self._edges_loaded = False
self._multihop_nodes = set()
self._hop_penalty = DEFAULT_HOP_PENALTY
Expand All @@ -148,7 +150,7 @@ def ensure_loaded(
include_deleted: bool = False,
*,
session: "Session",
):
) -> None:

if not rse_ids:
with self._lock:
Expand Down Expand Up @@ -177,7 +179,7 @@ def get_or_create(self, rse_id: str) -> "TN":
return rse_data

@property
def edges(self):
def edges(self) -> dict[tuple[TN, TN], TE]:
with self._lock:
return copy.copy(self._edges)

Expand All @@ -193,7 +195,7 @@ def get_or_create_edge(self, src_node: TN, dst_node: TN) -> "TE":
self._edges[src_node, dst_node] = edge = self._edge_cls(src_node, dst_node)
return edge

def delete_edge(self, src_node: TN, dst_node: TN):
def delete_edge(self, src_node: TN, dst_node: TN) -> None:
with self._lock:
edge = self._edges[src_node, dst_node]
edge.remove_from_nodes()
Expand All @@ -203,11 +205,11 @@ def multihop_enabled(self) -> bool:
return True if self._multihop_nodes else False

@read_session
def configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: LoggerFunction = logging.log):
def configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: "LoggerFunction" = logging.log) -> "Self":
with self._lock:
return self._configure_multihop(multihop_rse_ids=multihop_rse_ids, session=session, logger=logger)

def _configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: LoggerFunction = logging.log):
def _configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, session: "Session", logger: "LoggerFunction" = logging.log) -> "Self":

if multihop_rse_ids is None:
multihop_rse_expression = config_get('transfers', 'multihop_rse_expression', default='available_for_multihop=true', expiration_time=600, session=session)
Expand Down Expand Up @@ -236,7 +238,7 @@ def _configure_multihop(self, multihop_rse_ids: Optional[set[str]] = None, *, se
return self

@read_session
def ensure_edges_loaded(self, *, session: "Session"):
def ensure_edges_loaded(self, *, session: "Session") -> None:
"""
Ensure that all edges are loaded for the (sub-)set of nodes known by this topology object
"""
Expand All @@ -246,7 +248,7 @@ def ensure_edges_loaded(self, *, session: "Session"):
with self._lock:
return self._ensure_edges_loaded(session=session)

def _ensure_edges_loaded(self, *, session: "Session"):
def _ensure_edges_loaded(self, *, session: "Session") -> None:
stmt = select(
models.Distance
).where(
Expand Down Expand Up @@ -281,7 +283,7 @@ def _ensure_edges_loaded(self, *, session: "Session"):
@read_session
def search_shortest_paths(
self,
src_nodes: list[TN],
src_nodes: Iterable[TN],
dst_node: TN,
operation_src: str,
operation_dest: str,
Expand Down Expand Up @@ -434,7 +436,7 @@ def __init__(self, ttl, new_obj_fnc):
self._new_obj_fnc = new_obj_fnc
self._ttl = ttl

def get(self, logger=logging.log):
def get(self, logger: "LoggerFunction" = logging.log) -> object:
with self._lock:
if not self._object \
or not self._creation_time \
Expand All @@ -452,7 +454,7 @@ def get_hops(
multihop_rse_ids: Optional[set[str]] = None,
limit_dest_schemes: Optional[list[str]] = None,
*, session: "Session",
):
) -> list["HopDict"]:
"""
Get a list of hops needed to transfer date from source_rse_id to dest_rse_id.
Ideally, the list will only include one item (dest_rse_id) since no hops are needed.
Expand Down