Skip to content

Commit

Permalink
Typing; Add type annotations to lock.py; #6454
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Mar 18, 2024
1 parent 759c19a commit 746072b
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions lib/rucio/core/lock.py
Expand Up @@ -14,16 +14,17 @@
# limitations under the License.

import logging
from collections.abc import Iterable, Iterator
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, Union

from sqlalchemy.exc import DatabaseError
from sqlalchemy.sql.expression import and_, or_

import rucio.core.did
import rucio.core.rule
from rucio.common.exception import DataIdentifierNotFound
from rucio.common.types import InternalScope
from rucio.common.types import InternalScope, LoggerFunction
from rucio.core.lifetime_exception import define_eol
from rucio.core.rse import get_rse_attribute, get_rse_name
from rucio.db.sqla import models, filter_thread_work
Expand All @@ -35,7 +36,7 @@


@stream_session
def get_dataset_locks(scope, name, *, session: "Session"):
def get_dataset_locks(scope: InternalScope, name: str, *, session: "Session") -> Iterator[dict[str, Any]]:
"""
Get the dataset locks of a dataset
Expand Down Expand Up @@ -69,7 +70,7 @@ def get_dataset_locks(scope, name, *, session: "Session"):


@stream_session
def get_dataset_locks_bulk(dids, *, session: "Session"):
def get_dataset_locks_bulk(dids: Iterable[dict[str, Any]], *, session: "Session") -> Iterator[dict[str, Any]]:
"""
Get the dataset locks of a list of datasets or containers, recursively
Expand Down Expand Up @@ -102,7 +103,7 @@ def get_dataset_locks_bulk(dids, *, session: "Session"):


@stream_session
def get_dataset_locks_by_rse_id(rse_id, *, session: "Session"):
def get_dataset_locks_by_rse_id(rse_id: str, *, session: "Session") -> Iterator[dict[str, Any]]:
"""
Get the dataset locks of an RSE.
Expand Down Expand Up @@ -135,7 +136,7 @@ def get_dataset_locks_by_rse_id(rse_id, *, session: "Session"):


@read_session
def get_replica_locks(scope, name, nowait=False, restrict_rses=None, *, session: "Session"):
def get_replica_locks(scope: InternalScope, name: str, nowait: bool = False, restrict_rses: Optional[Iterable[str]] = None, *, session: "Session") -> list[models.ReplicaLock]:
"""
Get the active replica locks for a file
Expand All @@ -160,7 +161,7 @@ def get_replica_locks(scope, name, nowait=False, restrict_rses=None, *, session:


@read_session
def get_replica_locks_for_rule_id(rule_id, *, session: "Session"):
def get_replica_locks_for_rule_id(rule_id: str, *, session: "Session") -> list[dict[str, Any]]:
"""
Get the active replica locks for a rule_id.
Expand All @@ -186,7 +187,7 @@ def get_replica_locks_for_rule_id(rule_id, *, session: "Session"):


@read_session
def get_replica_locks_for_rule_id_per_rse(rule_id, *, session: "Session"):
def get_replica_locks_for_rule_id_per_rse(rule_id: str, *, session: "Session") -> list[dict[str, str]]:
"""
Get the active replica locks for a rule_id per rse.
Expand All @@ -208,9 +209,9 @@ def get_replica_locks_for_rule_id_per_rse(rule_id, *, session: "Session"):


@read_session
def get_files_and_replica_locks_of_dataset(scope, name, nowait=False, restrict_rses=None, only_stuck=False,
total_threads=None, thread_id=None,
*, session: "Session"):
def get_files_and_replica_locks_of_dataset(scope: InternalScope, name: str, nowait: bool = False, restrict_rses: Optional[Iterable[str]] = None, only_stuck: bool = False,
total_threads: Optional[int] = None, thread_id: Optional[int] = None,
*, session: "Session") -> dict[tuple[InternalScope, str], Union[models.ReplicaLock, list[models.ReplicaLock]]]:
"""
Get all the files of a dataset and, if existing, all locks of the file.
Expand Down Expand Up @@ -316,7 +317,7 @@ def get_files_and_replica_locks_of_dataset(scope, name, nowait=False, restrict_r


@transactional_session
def successful_transfer(scope, name, rse_id, nowait, *, session: "Session", logger=logging.log):
def successful_transfer(scope: InternalScope, name: str, rse_id: str, nowait: bool, *, session: "Session", logger: LoggerFunction = logging.log) -> None:
"""
Update the state of all replica locks because of an successful transfer
Expand Down Expand Up @@ -385,7 +386,8 @@ def successful_transfer(scope, name, rse_id, nowait, *, session: "Session", logg


@transactional_session
def failed_transfer(scope, name, rse_id, error_message=None, broken_rule_id=None, broken_message=None, nowait=True, *, session: "Session", logger=logging.log):
def failed_transfer(scope: InternalScope, name: str, rse_id: str, error_message: Optional[str] = None, broken_rule_id: Optional[str] = None,
broken_message: Optional[str] = None, nowait: bool = True, *, session: "Session", logger: LoggerFunction = logging.log) -> None:
"""
Update the state of all replica locks because of a failed transfer.
If a transfer is permanently broken for a rule, the broken_rule_id should be filled which puts this rule into the SUSPENDED state.
Expand Down Expand Up @@ -428,7 +430,10 @@ def failed_transfer(scope, name, rse_id, error_message=None, broken_rule_id=None
pass
elif lock.rule_id == broken_rule_id:
rule.state = RuleState.SUSPENDED
rule.error = (broken_message[:245] + '...') if len(broken_message) > 245 else broken_message
if broken_message is not None and len(broken_message) > 245:
rule.error = (broken_message[:245] + '...')
else:
rule.error = broken_message
# Try to update the DatasetLocks
if rule.grouping != RuleGrouping.NONE:
ds_locks = session.query(models.DatasetLock).with_for_update(nowait=nowait).filter_by(rule_id=rule.id)
Expand All @@ -443,14 +448,17 @@ def failed_transfer(scope, name, rse_id, error_message=None, broken_rule_id=None
for ds_lock in ds_locks:
ds_lock.state = LockState.STUCK
if rule.error != error_message:
rule.error = (error_message[:245] + '...') if len(error_message) > 245 else error_message
if error_message is not None and len(error_message) > 245:
rule.error = (error_message[:245] + '...')
else:
rule.error = error_message

# Insert rule history
rucio.core.rule.insert_rule_history(rule=rule, recent=True, longterm=False, session=session)


@transactional_session
def touch_dataset_locks(dataset_locks, *, session: "Session"):
def touch_dataset_locks(dataset_locks: Iterable[dict[str, Any]], *, session: "Session") -> bool:
"""
Update the accessed_at timestamp of the given dataset locks + eol_at.
Expand Down

0 comments on commit 746072b

Please sign in to comment.