Skip to content

Commit

Permalink
Typing: Fix type annotations in rule.py; #6454
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Mar 13, 2024
1 parent 1edbe2f commit ea4b312
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 32 deletions.
8 changes: 4 additions & 4 deletions lib/rucio/common/types.py
Expand Up @@ -166,11 +166,11 @@ class RuleDict(TypedDict):
copies: int
rse_expression: str
grouping: Literal['ALL', 'DATASET', 'NONE']
weight: str
lifetime: int
weight: Optional[str]
lifetime: Optional[int]
locked: bool
subscription_id: str
subscription_id: Optional[str]
source_replica_expression: Optional[str]
activity: str
notify: Optional[Literal['Y', 'N', 'C']]
notify: Optional[Literal['Y', 'N', 'C', 'P']]
purge_replicas: bool
57 changes: 33 additions & 24 deletions lib/rucio/core/rule.py
Expand Up @@ -148,13 +148,13 @@ def add_rule(dids: Sequence[dict[str, Any]],
copies: int,
rse_expression: str,
grouping: Literal['ALL', 'DATASET', 'NONE'],
weight: str,
lifetime: int,
weight: Optional[str],
lifetime: Optional[int],
locked: bool,
subscription_id: str,
subscription_id: Optional[str],
source_replica_expression: Optional[str] = None,
activity: str = 'User Subscriptions',
notify: Optional[Literal['Y', 'N', 'C']] = None,
notify: Optional[Literal['Y', 'N', 'C', 'P']] = None,
purge_replicas: bool = False,
ignore_availability: bool = False,
comment: Optional[str] = None,
Expand Down Expand Up @@ -253,7 +253,7 @@ def add_rule(dids: Sequence[dict[str, Any]],

expires_at = datetime.utcnow() + timedelta(seconds=lifetime) if lifetime is not None else None

notify_value = {'Y': RuleNotification.YES, 'C': RuleNotification.CLOSE, 'P': RuleNotification.PROGRESS}.get(str(notify or ''), RuleNotification.NO)
notify_value = {'Y': RuleNotification.YES, 'C': RuleNotification.CLOSE, 'P': RuleNotification.PROGRESS}.get(notify or '', RuleNotification.NO)

for elem in dids:
# 3. Get the did
Expand Down Expand Up @@ -548,7 +548,8 @@ def add_rules(dids: Sequence[dict[str, Any]],
with METRICS.timer('add_rules.create_rule'):
grouping = {'ALL': RuleGrouping.ALL, 'NONE': RuleGrouping.NONE}.get(str(rule.get('grouping')), RuleGrouping.DATASET)

expires_at: Optional[datetime] = datetime.utcnow() + timedelta(seconds=rule.get('lifetime')) if rule.get('lifetime') is not None else None
rule_lifetime: Optional[int] = rule.get('lifetime')
expires_at: Optional[datetime] = datetime.utcnow() + timedelta(seconds=rule_lifetime) if rule_lifetime is not None else None

notify = {'Y': RuleNotification.YES, 'C': RuleNotification.CLOSE, 'P': RuleNotification.PROGRESS, None: RuleNotification.NO}.get(rule.get('notify'))

Expand Down Expand Up @@ -1747,7 +1748,7 @@ def re_evaluate_did(scope: InternalScope, name: str, rule_evaluation_action: DID


@read_session
def get_updated_dids(total_workers: int, worker_number: int, limit: int = 100, blocked_dids: Iterable[tuple[str, str]] = [], *, session: "Session") -> list[tuple[str, str]]:
def get_updated_dids(total_workers: int, worker_number: int, limit: int = 100, blocked_dids: Iterable[tuple[str, str]] = [], *, session: "Session") -> list[tuple[str, InternalScope, str, DIDReEvaluation]]:
"""
Get updated dids.
Expand Down Expand Up @@ -1785,7 +1786,15 @@ def get_updated_dids(total_workers: int, worker_number: int, limit: int = 100, b


@read_session
def get_rules_beyond_eol(date_check: datetime, worker_number: int, total_workers: int, *, session: "Session") -> list[dict[str, Any]]:
def get_rules_beyond_eol(date_check: datetime, worker_number: int, total_workers: int, *, session: "Session"
) -> list[tuple[InternalScope,
str,
str,
bool,
str,
Optional[datetime],
Optional[datetime],
InternalAccount]]:
"""
Get rules which have eol_at before a certain date.
Expand Down Expand Up @@ -1856,7 +1865,7 @@ def get_injected_rules(total_workers: int,
blocked_rules: Sequence[str] = [],
*,
session: "Session"
) -> list[tuple[str, str]]:
) -> list[str]:
"""
Get rules to be injected.
Expand Down Expand Up @@ -1898,7 +1907,7 @@ def get_stuck_rules(total_workers: int,
blocked_rules: Sequence[str] = [],
*,
session: "Session"
) -> list[tuple[str, str]]:
) -> list[str]:
"""
Get stuck rules.
Expand Down Expand Up @@ -2601,9 +2610,9 @@ def list_rules_for_rse_decommissioning(

@transactional_session
def __find_missing_locks_and_create_them(datasetfiles: Sequence[dict[str, Any]],
locks: dict[tuple[str, str], Sequence[models.ReplicaLock]],
replicas: dict[tuple[str, str], Any],
source_replicas: dict[tuple[str, str], Any],
locks: dict[tuple[InternalScope, str], Sequence[models.ReplicaLock]],
replicas: dict[tuple[InternalScope, str], Any],
source_replicas: dict[tuple[InternalScope, str], Any],
rseselector: RSESelector,
rule: models.ReplicationRule,
source_rses: Sequence[str],
Expand Down Expand Up @@ -2658,7 +2667,7 @@ def __find_missing_locks_and_create_them(datasetfiles: Sequence[dict[str, Any]],

@transactional_session
def __find_surplus_locks_and_remove_them(datasetfiles: Sequence[dict[str, Any]],
locks: dict[tuple[str, str], list[models.ReplicaLock]],
locks: dict[tuple[InternalScope, str], list[models.ReplicaLock]],
rule: models.ReplicationRule,
*,
session: "Session",
Expand Down Expand Up @@ -2708,9 +2717,9 @@ def __find_surplus_locks_and_remove_them(datasetfiles: Sequence[dict[str, Any]],

@transactional_session
def __find_stuck_locks_and_repair_them(datasetfiles: Sequence[dict[str, Any]],
locks: dict[tuple[str, str], Sequence[models.ReplicaLock]],
replicas: dict[tuple[str, str], Any],
source_replicas: dict[tuple[str, str], Any],
locks: dict[tuple[InternalScope, str], Sequence[models.ReplicaLock]],
replicas: dict[tuple[InternalScope, str], Any],
source_replicas: dict[tuple[InternalScope, str], Any],
rseselector: RSESelector,
rule: models.ReplicationRule,
source_rses: Sequence[str],
Expand Down Expand Up @@ -3091,7 +3100,7 @@ def __resolve_did_to_locks_and_replicas(did: models.DataIdentifier,
session: "Session"
) -> tuple[list[dict[str, Any]],
dict[tuple[str, str], models.ReplicaLock],
dict[tuple[str, str], Any],
dict[tuple[str, str], models.RSEFileAssociation],
dict[tuple[str, str], str]]:
"""
Resolves a did to its constituent childs and reads the locks and replicas of all the constituent files.
Expand Down Expand Up @@ -3197,7 +3206,7 @@ def __resolve_dids_to_locks_and_replicas(dids: Sequence[models.DataIdentifierAss
session: "Session"
) -> tuple[list[dict[str, Any]],
dict[tuple[str, str], models.ReplicaLock],
dict[tuple[str, str], Any],
dict[tuple[str, str], models.RSEFileAssociation],
dict[tuple[str, str], str]]:
"""
Resolves a list of dids to its constituent childs and reads the locks and replicas of all the constituent files.
Expand Down Expand Up @@ -3311,9 +3320,9 @@ def __resolve_dids_to_locks_and_replicas(dids: Sequence[models.DataIdentifierAss

@transactional_session
def __create_locks_replicas_transfers(datasetfiles: Sequence[dict[str, Any]],
locks: dict[tuple[str, str], Sequence[models.ReplicaLock]],
replicas: dict[tuple[str, str], Any],
source_replicas: dict[tuple[str, str], Any],
locks: dict[tuple[InternalScope, str], Sequence[models.ReplicaLock]],
replicas: dict[tuple[InternalScope, str], Any],
source_replicas: dict[tuple[InternalScope, str], Any],
rseselector: RSESelector,
rule: models.ReplicationRule,
preferred_rse_ids: Sequence[str] = [],
Expand Down Expand Up @@ -3509,7 +3518,7 @@ def _create_recipients_list(rse_expression: str,
filter_: Optional[str] = None,
*,
session: "Session"
) -> list[tuple[str, InternalAccount]]:
) -> list[tuple[str, Union[str, InternalAccount]]]:
"""
Create a list of recipients for a notification email based on rse_expression.
Expand Down Expand Up @@ -3675,7 +3684,7 @@ def archive_localgroupdisk_datasets(scope: InternalScope,

@policy_filter
@read_session
def get_scratch_policy(account: str, rses: Sequence[dict[str, Any]], lifetime: int, *, session: "Session") -> int:
def get_scratch_policy(account: InternalAccount, rses: Sequence[dict[str, Any]], lifetime: Optional[int], *, session: "Session") -> Optional[int]:
"""
ATLAS policy for rules on SCRATCHDISK
Expand Down
9 changes: 5 additions & 4 deletions lib/rucio/core/rule_grouping.py
Expand Up @@ -26,6 +26,7 @@
import rucio.core.replica
from rucio.common.config import config_get_int
from rucio.common.exception import InsufficientTargetRSEs
from rucio.common.types import InternalScope
from rucio.core import account_counter, rse_counter, request as request_core
from rucio.core.rse_selector import RSESelector
from rucio.core.rse import get_rse, get_rse_attribute, get_rse_name
Expand All @@ -38,8 +39,8 @@


@transactional_session
def apply_rule_grouping(datasetfiles: Sequence[dict[str, Any]], locks: dict[tuple[str, str], models.ReplicaLock],
replicas: dict[tuple[str, str], Any], source_replicas: dict[tuple[str, str], Any],
def apply_rule_grouping(datasetfiles: Sequence[dict[str, Any]], locks: dict[tuple[InternalScope, str], models.ReplicaLock],
replicas: dict[tuple[InternalScope, str], Any], source_replicas: dict[tuple[InternalScope, str], Any],
rseselector: RSESelector, rule: models.ReplicationRule, preferred_rse_ids: Sequence[str] = [],
source_rses: Sequence[str] = [], *,
session: "Session") -> tuple[dict[str, list[dict[str, models.RSEFileAssociation]]],
Expand Down Expand Up @@ -104,8 +105,8 @@ def apply_rule_grouping(datasetfiles: Sequence[dict[str, Any]], locks: dict[tupl


@transactional_session
def repair_stuck_locks_and_apply_rule_grouping(datasetfiles: Sequence[dict[str, Any]], locks: dict[tuple[str, str], models.ReplicaLock],
replicas: dict[tuple[str, str], Any], source_replicas: dict[tuple[str, str], Any],
def repair_stuck_locks_and_apply_rule_grouping(datasetfiles: Sequence[dict[str, Any]], locks: dict[tuple[InternalScope, str], models.ReplicaLock],
replicas: dict[tuple[InternalScope, str], Any], source_replicas: dict[tuple[InternalScope, str], Any],
rseselector: RSESelector, rule: models.ReplicationRule, source_rses: Sequence[str], *,
session: "Session") -> tuple[dict[str, list[dict[str, models.RSEFileAssociation]]],
dict[str, list[dict[str, models.ReplicaLock]]],
Expand Down

0 comments on commit ea4b312

Please sign in to comment.