Skip to content

Commit

Permalink
Rules: Add RuleDict typed dict; #6454
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Feb 29, 2024
1 parent 9e1d03e commit 511512c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
17 changes: 16 additions & 1 deletion lib/rucio/common/types.py
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, TypedDict, Union
from typing import Any, Callable, Literal, Optional, TypedDict, Union


class InternalType(object):
Expand Down Expand Up @@ -159,3 +159,18 @@ class RSESettingsDict(TypedDict):
deterministic: bool
domain: list[str]
protocols: list[RSEProtocolDict]


class RuleDict(TypedDict):
account: InternalAccount
copies: int
rse_expression: str
grouping: Literal['ALL', 'DATASET', 'NONE']
weight: str
lifetime: int
locked: bool
subscription_id: str
source_replica_expression: Optional[str]
activity: str
notify: Optional[Literal['Y', 'N', 'C']]
purge_replicas: bool
20 changes: 12 additions & 8 deletions lib/rucio/core/rule.py
Expand Up @@ -22,7 +22,7 @@
from os import path
from re import match
from string import Template
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Type, TypeVar, Sequence

from dogpile.cache.api import NO_VALUE
from sqlalchemy import select, update
Expand All @@ -45,7 +45,7 @@
InvalidSourceReplicaExpression)
from rucio.common.policy import policy_filter, get_scratchdisk_lifetime
from rucio.common.schema import validate_schema
from rucio.common.types import InternalScope, InternalAccount, LoggerFunction
from rucio.common.types import InternalScope, InternalAccount, LoggerFunction, RuleDict
from rucio.common.utils import str_to_date, sizefmt, chunks
from rucio.core import account_counter, rse_counter, request as request_core, transfer as transfer_core
from rucio.core.account import get_account
Expand Down Expand Up @@ -143,8 +143,8 @@ def default(rule: models.ReplicationRule, did: models.DataIdentifier, session: '


@transactional_session
def add_rule(dids: Sequence[dict[str, Any]], account: InternalAccount, copies: int, rse_expression: str, weight: str, lifetime: int, locked: bool, subscription_id: str,
grouping: RuleGrouping = RuleGrouping.DATASET, source_replica_expression: Optional[str] = None, activity: str = 'User Subscriptions', notify: RuleNotification = RuleNotification.NO, purge_replicas: bool = False,
def add_rule(dids: Sequence[dict[str, Any]], account: InternalAccount, copies: int, rse_expression: str, grouping: Literal['ALL', 'DATASET', 'NONE'], weight: str, lifetime: int, locked: bool, subscription_id: str,
source_replica_expression: Optional[str] = None, activity: str = 'User Subscriptions', notify: Optional[Literal['Y', 'N', 'C']] = None, purge_replicas: bool = False,
ignore_availability: bool = False, comment: Optional[str] = None, ask_approval: bool = False, asynchronous: bool = False, ignore_account_limit: bool = False,
priority: int = 3, delay_injection: Optional[int] = None, split_container: bool = False, meta: Optional[dict[str, Any]] = None, *, session: "Session", logger: LoggerFunction = logging.log) -> list[str]:
"""
Expand All @@ -163,7 +163,7 @@ def add_rule(dids: Sequence[dict[str, Any]], account: InternalAccount, copies: i
:param subscription_id: The subscription_id, if the rule is created by a subscription.
:param source_replica_expression: Only use replicas as source from this RSEs.
:param activity: Activity to be passed on to the conveyor.
:param notify: Notification setting of the rule.
:param notify: Notification setting of the rule ('Y', 'N', 'C'; None = 'N').
:param purge_replicas: Purge setting if a replica should be directly deleted after the rule is deleted.
:param ignore_availability: Option to ignore the availability of RSEs.
:param comment: Comment about the rule.
Expand All @@ -185,6 +185,8 @@ def add_rule(dids: Sequence[dict[str, Any]], account: InternalAccount, copies: i

rule_ids = []

grouping_value = {'ALL': RuleGrouping.ALL, 'NONE': RuleGrouping.NONE}.get(grouping, RuleGrouping.DATASET)

with METRICS.timer('add_rule.total'):
# 1. Resolve the rse_expression into a list of RSE-ids
with METRICS.timer('add_rule.parse_rse_expression'):
Expand Down Expand Up @@ -229,6 +231,8 @@ def add_rule(dids: Sequence[dict[str, Any]], account: InternalAccount, copies: i

expires_at = datetime.utcnow() + timedelta(seconds=lifetime) if lifetime is not None else None

notify_value = {'Y': RuleNotification.YES, 'C': RuleNotification.CLOSE, 'P': RuleNotification.PROGRESS}.get(str(notify or ''), RuleNotification.NO)

for elem in dids:
# 3. Get the did
with METRICS.timer('add_rule.get_did'):
Expand Down Expand Up @@ -286,13 +290,13 @@ def add_rule(dids: Sequence[dict[str, Any]], account: InternalAccount, copies: i
copies=copies,
rse_expression=rse_expression,
locked=locked,
grouping=grouping,
grouping=grouping_value,
expires_at=expires_at,
weight=weight,
source_replica_expression=source_replica_expression,
activity=activity,
subscription_id=subscription_id,
notification=notify,
notification=notify_value,
purge_replicas=purge_replicas,
ignore_availability=ignore_availability,
comments=comment,
Expand Down Expand Up @@ -383,7 +387,7 @@ def add_rule(dids: Sequence[dict[str, Any]], account: InternalAccount, copies: i


@transactional_session
def add_rules(dids: Sequence[dict[str, Any]], rules: Sequence[models.ReplicationRule], *, session: "Session", logger: LoggerFunction = logging.log) -> dict[tuple[InternalScope, str], list[str]]:
def add_rules(dids: Sequence[dict[str, Any]], rules: Sequence[RuleDict], *, session: "Session", logger: LoggerFunction = logging.log) -> dict[tuple[InternalScope, str], list[str]]:
"""
Adds a list of replication rules to every did in dids
Expand Down

0 comments on commit 511512c

Please sign in to comment.