Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing; Add type annotations to lock.py #6562

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 24 additions & 16 deletions lib/rucio/core/lock.py
Expand Up @@ -14,16 +14,17 @@
# limitations under the License.

import logging
from collections.abc import Iterable, Iterator
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, Union

from sqlalchemy.exc import DatabaseError
from sqlalchemy.sql.expression import and_, or_

import rucio.core.did
import rucio.core.rule
from rucio.common.exception import DataIdentifierNotFound
from rucio.common.types import InternalScope
from rucio.common.types import InternalScope, LoggerFunction
from rucio.core.lifetime_exception import define_eol
from rucio.core.rse import get_rse_attribute, get_rse_name
from rucio.db.sqla import models, filter_thread_work
Expand All @@ -35,7 +36,7 @@


@stream_session
def get_dataset_locks(scope, name, *, session: "Session"):
def get_dataset_locks(scope: InternalScope, name: str, *, session: "Session") -> Iterator[dict[str, Any]]:
"""
Get the dataset locks of a dataset

Expand Down Expand Up @@ -69,7 +70,7 @@ def get_dataset_locks(scope, name, *, session: "Session"):


@stream_session
def get_dataset_locks_bulk(dids, *, session: "Session"):
def get_dataset_locks_bulk(dids: Iterable[dict[str, Any]], *, session: "Session") -> Iterator[dict[str, Any]]:
"""
Get the dataset locks of a list of datasets or containers, recursively

Expand Down Expand Up @@ -102,7 +103,7 @@ def get_dataset_locks_bulk(dids, *, session: "Session"):


@stream_session
def get_dataset_locks_by_rse_id(rse_id, *, session: "Session"):
def get_dataset_locks_by_rse_id(rse_id: str, *, session: "Session") -> Iterator[dict[str, Any]]:
"""
Get the dataset locks of an RSE.

Expand Down Expand Up @@ -135,7 +136,7 @@ def get_dataset_locks_by_rse_id(rse_id, *, session: "Session"):


@read_session
def get_replica_locks(scope, name, nowait=False, restrict_rses=None, *, session: "Session"):
def get_replica_locks(scope: InternalScope, name: str, nowait: bool = False, restrict_rses: Optional[Iterable[str]] = None, *, session: "Session") -> list[models.ReplicaLock]:
"""
Get the active replica locks for a file

Expand All @@ -160,7 +161,7 @@ def get_replica_locks(scope, name, nowait=False, restrict_rses=None, *, session:


@read_session
def get_replica_locks_for_rule_id(rule_id, *, session: "Session"):
def get_replica_locks_for_rule_id(rule_id: str, *, session: "Session") -> list[dict[str, Any]]:
"""
Get the active replica locks for a rule_id.

Expand All @@ -186,7 +187,7 @@ def get_replica_locks_for_rule_id(rule_id, *, session: "Session"):


@read_session
def get_replica_locks_for_rule_id_per_rse(rule_id, *, session: "Session"):
def get_replica_locks_for_rule_id_per_rse(rule_id: str, *, session: "Session") -> list[dict[str, str]]:
"""
Get the active replica locks for a rule_id per rse.

Expand All @@ -208,9 +209,9 @@ def get_replica_locks_for_rule_id_per_rse(rule_id, *, session: "Session"):


@read_session
def get_files_and_replica_locks_of_dataset(scope, name, nowait=False, restrict_rses=None, only_stuck=False,
total_threads=None, thread_id=None,
*, session: "Session"):
def get_files_and_replica_locks_of_dataset(scope: InternalScope, name: str, nowait: bool = False, restrict_rses: Optional[Iterable[str]] = None, only_stuck: bool = False,
total_threads: Optional[int] = None, thread_id: Optional[int] = None,
*, session: "Session") -> dict[tuple[InternalScope, str], Union[models.ReplicaLock, list[models.ReplicaLock]]]:
"""
Get all the files of a dataset and, if existing, all locks of the file.

Expand Down Expand Up @@ -316,7 +317,7 @@ def get_files_and_replica_locks_of_dataset(scope, name, nowait=False, restrict_r


@transactional_session
def successful_transfer(scope, name, rse_id, nowait, *, session: "Session", logger=logging.log):
def successful_transfer(scope: InternalScope, name: str, rse_id: str, nowait: bool, *, session: "Session", logger: LoggerFunction = logging.log) -> None:
"""
Update the state of all replica locks because of an successful transfer

Expand Down Expand Up @@ -385,7 +386,8 @@ def successful_transfer(scope, name, rse_id, nowait, *, session: "Session", logg


@transactional_session
def failed_transfer(scope, name, rse_id, error_message=None, broken_rule_id=None, broken_message=None, nowait=True, *, session: "Session", logger=logging.log):
def failed_transfer(scope: InternalScope, name: str, rse_id: str, error_message: Optional[str] = None, broken_rule_id: Optional[str] = None,
broken_message: Optional[str] = None, nowait: bool = True, *, session: "Session", logger: LoggerFunction = logging.log) -> None:
"""
Update the state of all replica locks because of a failed transfer.
If a transfer is permanently broken for a rule, the broken_rule_id should be filled which puts this rule into the SUSPENDED state.
Expand Down Expand Up @@ -428,7 +430,10 @@ def failed_transfer(scope, name, rse_id, error_message=None, broken_rule_id=None
pass
elif lock.rule_id == broken_rule_id:
rule.state = RuleState.SUSPENDED
rule.error = (broken_message[:245] + '...') if len(broken_message) > 245 else broken_message
if broken_message is not None and len(broken_message) > 245:
rule.error = (broken_message[:245] + '...')
else:
rule.error = broken_message
# Try to update the DatasetLocks
if rule.grouping != RuleGrouping.NONE:
ds_locks = session.query(models.DatasetLock).with_for_update(nowait=nowait).filter_by(rule_id=rule.id)
Expand All @@ -443,14 +448,17 @@ def failed_transfer(scope, name, rse_id, error_message=None, broken_rule_id=None
for ds_lock in ds_locks:
ds_lock.state = LockState.STUCK
if rule.error != error_message:
rule.error = (error_message[:245] + '...') if len(error_message) > 245 else error_message
if error_message is not None and len(error_message) > 245:
rule.error = (error_message[:245] + '...')
else:
rule.error = error_message

# Insert rule history
rucio.core.rule.insert_rule_history(rule=rule, recent=True, longterm=False, session=session)


@transactional_session
def touch_dataset_locks(dataset_locks, *, session: "Session"):
def touch_dataset_locks(dataset_locks: Iterable[dict[str, Any]], *, session: "Session") -> bool:
"""
Update the accessed_at timestamp of the given dataset locks + eol_at.

Expand Down