Skip to content

Commit

Permalink
Enhancement: Removing mutable params, adding type hints; #5169
Browse files Browse the repository at this point in the history
  • Loading branch information
erlingstaff authored and bari12 committed May 10, 2024
1 parent b672ffb commit 1b631a7
Show file tree
Hide file tree
Showing 22 changed files with 339 additions and 101 deletions.
8 changes: 4 additions & 4 deletions lib/rucio/api/account.py
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Optional

import rucio.api.permission
import rucio.common.exception
Expand Down Expand Up @@ -115,7 +116,7 @@ def update_account(account, key, value, issuer='root', vo='def', *, session: "Se


@stream_session
def list_accounts(filter_={}, vo='def', *, session: "Session"):
def list_accounts(filter_: Optional[dict[str, Any]] = None, vo: str = 'def', *, session: "Session") -> Iterator[dict[str, Any]]:
"""
Lists all the Rucio account names.
Expand All @@ -128,8 +129,7 @@ def list_accounts(filter_={}, vo='def', *, session: "Session"):
:returns: List of all accounts.
"""
# If filter is empty, create a new dict to avoid overwriting the function's default
if not filter_:
filter_ = {}
filter_ = filter_ or {}

if 'account' in filter_:
filter_['account'] = InternalAccount(filter_['account'], vo=vo)
Expand Down
40 changes: 29 additions & 11 deletions lib/rucio/api/did.py
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterator
from collections.abc import Iterator, Sequence
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import rucio.api.permission
from rucio.common.constants import RESERVED_KEYS
Expand All @@ -29,7 +29,7 @@
from rucio.db.sqla.session import read_session, stream_session, transactional_session

if TYPE_CHECKING:
from typing import Any, Optional
from typing import Optional

from sqlalchemy.orm import Session

Expand Down Expand Up @@ -67,7 +67,22 @@ def list_dids(scope, filters, did_type='collection', ignore_case=False, limit=No


@transactional_session
def add_did(scope, name, did_type, issuer, account=None, statuses={}, meta={}, rules=[], lifetime=None, dids=[], rse=None, vo='def', *, session: "Session"):
def add_did(
scope: str,
name: str,
did_type: str,
issuer: str,
account: "Optional[str]" = None,
statuses: "Optional[dict[str, str]]" = None,
meta: "Optional[dict[str, str]]" = None,
rules: "Optional[Sequence[dict[str, Any]]]" = None,
lifetime: "Optional[str]" = None,
dids: "Optional[Sequence[dict[str, Any]]]" = None,
rse: "Optional[str]" = None,
vo: str = 'def',
*,
session: "Session"
) -> None:
"""
Add data did.
Expand All @@ -85,6 +100,10 @@ def add_did(scope, name, did_type, issuer, account=None, statuses={}, meta={}, r
:param vo: The VO to act on.
:param session: The database session in use.
"""
statuses = statuses or {}
meta = meta or {}
rules = rules or []
dids = dids or []
v_did = {'name': name, 'type': did_type.upper(), 'scope': scope}
validate_schema(name='did', obj=v_did, vo=vo)
validate_schema(name='dids', obj=dids, vo=vo)
Expand All @@ -93,10 +112,9 @@ def add_did(scope, name, did_type, issuer, account=None, statuses={}, meta={}, r
if not rucio.api.permission.has_permission(issuer=issuer, vo=vo, action='add_did', kwargs=kwargs, session=session):
raise rucio.common.exception.AccessDenied('Account %s can not add data identifier to scope %s' % (issuer, scope))

if account is not None:
account = InternalAccount(account, vo=vo)
issuer = InternalAccount(issuer, vo=vo)
scope = InternalScope(scope, vo=vo)
owner_account = None if account is None else InternalAccount(account, vo=vo)
issuer_account = InternalAccount(issuer, vo=vo)
internal_scope = InternalScope(scope, vo=vo)
for d in dids:
d['scope'] = InternalScope(d['scope'], vo=vo)
for r in rules:
Expand All @@ -108,7 +126,7 @@ def add_did(scope, name, did_type, issuer, account=None, statuses={}, meta={}, r

if did_type == 'DATASET':
# naming_convention validation
extra_meta = naming_convention.validate_name(scope=scope, name=name, did_type='D', session=session)
extra_meta = naming_convention.validate_name(scope=internal_scope, name=name, did_type='D', session=session)

# merge extra_meta with meta
for k in extra_meta or {}:
Expand All @@ -121,7 +139,7 @@ def add_did(scope, name, did_type, issuer, account=None, statuses={}, meta={}, r
# Validate metadata
meta_convention_core.validate_meta(meta=meta, did_type=DIDType[did_type.upper()], session=session)

return did.add_did(scope=scope, name=name, did_type=DIDType[did_type.upper()], account=account or issuer,
return did.add_did(scope=internal_scope, name=name, did_type=DIDType[did_type.upper()], account=owner_account or issuer_account,
statuses=statuses, meta=meta, rules=rules, lifetime=lifetime,
dids=dids, rse_id=rse_id, session=session)

Expand Down Expand Up @@ -330,7 +348,7 @@ def list_content_history(scope, name, vo='def', *, session: "Session"):


@stream_session
def bulk_list_files(dids: "list[dict[str, Any]]", long: bool = False, vo: str = 'def', *, session: "Session") -> "Iterator[dict[str, Any]]":
def bulk_list_files(dids: list[dict[str, Any]], long: bool = False, vo: str = 'def', *, session: "Session") -> Iterator[dict[str, Any]]:
"""
List file contents of a list of data identifiers.
Expand Down
6 changes: 4 additions & 2 deletions lib/rucio/api/replica.py
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.

import datetime
from typing import TYPE_CHECKING
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Optional

from rucio.api import permission
from rucio.common import exception
Expand Down Expand Up @@ -418,7 +419,7 @@ def list_dataset_replicas_vp(scope, name, deep=False, vo='def', *, session: "Ses


@stream_session
def list_datasets_per_rse(rse, filters={}, limit=None, vo='def', *, session: "Session"):
def list_datasets_per_rse(rse: str, filters: Optional[dict[str, Any]] = None, limit: Optional[int] = None, vo: str = 'def', *, session: "Session") -> Iterator[dict[str, Any]]:
"""
:param scope: The scope of the dataset.
:param name: The name of the dataset.
Expand All @@ -430,6 +431,7 @@ def list_datasets_per_rse(rse, filters={}, limit=None, vo='def', *, session: "Se
:returns: A list of dict dataset replicas
"""

filters = filters or {}
rse_id = get_rse_id(rse=rse, vo=vo, session=session)
if 'scope' in filters:
filters['scope'] = InternalScope(filters['scope'], vo=vo)
Expand Down
7 changes: 3 additions & 4 deletions lib/rucio/api/rse.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from rucio.api import permission
from rucio.common import exception
Expand Down Expand Up @@ -109,7 +109,7 @@ def del_rse(rse, issuer, vo='def', *, session: "Session"):


@read_session
def list_rses(filters={}, vo='def', *, session: "Session"):
def list_rses(filters: "Optional[dict[str, Any]]" = None, vo: str = 'def', *, session: "Session") -> list[dict[str, Any]]:
"""
Lists all RSEs.
Expand All @@ -119,8 +119,7 @@ def list_rses(filters={}, vo='def', *, session: "Session"):
:returns: List of all RSEs.
"""
if not filters:
filters = {}
filters = filters or {}

filters['vo'] = vo

Expand Down
8 changes: 4 additions & 4 deletions lib/rucio/api/rule.py
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Optional

from rucio.api.permission import has_permission
from rucio.common.config import config_get_bool
Expand Down Expand Up @@ -134,7 +135,7 @@ def get_replication_rule(rule_id, issuer, vo='def', *, session: "Session"):


@stream_session
def list_replication_rules(filters={}, vo='def', *, session: "Session"):
def list_replication_rules(filters: Optional[dict[str, Any]] = None, vo: str = 'def', *, session: "Session") -> Iterator[dict[str, Any]]:
"""
Lists replication rules based on a filter.
Expand All @@ -143,8 +144,7 @@ def list_replication_rules(filters={}, vo='def', *, session: "Session"):
:param session: The database session in use.
"""
# If filters is empty, create a new dict to avoid overwriting the function's default
if not filters:
filters = {}
filters = filters or {}

if 'scope' in filters:
scope = filters['scope']
Expand Down
7 changes: 3 additions & 4 deletions lib/rucio/api/scope.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

import rucio.api.permission
import rucio.common.exception
Expand All @@ -26,7 +26,7 @@


@read_session
def list_scopes(filter_={}, vo='def', *, session: "Session"):
def list_scopes(filter_: Optional[dict[str, Any]] = None, vo: str = 'def', *, session: "Session") -> list[str]:
"""
Lists all scopes.
Expand All @@ -37,8 +37,7 @@ def list_scopes(filter_={}, vo='def', *, session: "Session"):
:returns: A list containing all scopes.
"""
# If filter is empty, create a new dict to avoid overwriting the function's default
if not filter_:
filter_ = {}
filter_ = filter_ or {}

if 'scope' in filter_:
filter_['scope'] = InternalScope(scope=filter_['scope'], vo=vo)
Expand Down
63 changes: 56 additions & 7 deletions lib/rucio/client/downloadclient.py
Expand Up @@ -25,6 +25,7 @@
import time
from queue import Empty, Queue, deque
from threading import Thread
from typing import Any, Optional

from rucio import version
from rucio.client.client import Client
Expand Down Expand Up @@ -174,7 +175,14 @@ def __init__(self, client=None, logger=None, tracing=True, check_admin=False, ch
self.extraction_tools.append(BaseExtractionTool('tar', '--version', extract_args, logger=self.logger))
self.extract_scope_convention = config_get('common', 'extract_scope', False, None)

def download_pfns(self, items, num_threads=2, trace_custom_fields={}, traces_copy_out=None, deactivate_file_download_exceptions=False):
def download_pfns(
self,
items: list[dict[str, Any]],
num_threads: int = 2,
trace_custom_fields: Optional[dict[str, Any]] = None,
traces_copy_out: Optional[list[dict[str, Any]]] = None,
deactivate_file_download_exceptions: bool = False
) -> list[dict[str, Any]]:
"""
Download items with a given PFN. This function can only download files, no datasets.
Expand Down Expand Up @@ -202,6 +210,7 @@ def download_pfns(self, items, num_threads=2, trace_custom_fields={}, traces_cop
:raises NotAllFilesDownloaded: if not all files could be downloaded
:raises RucioException: if something unexpected went wrong during the download
"""
trace_custom_fields = trace_custom_fields or {}
logger = self.logger
trace_custom_fields['uuid'] = generate_uuid()

Expand Down Expand Up @@ -250,8 +259,15 @@ def download_pfns(self, items, num_threads=2, trace_custom_fields={}, traces_cop

return self._check_output(output_items, deactivate_file_download_exceptions=deactivate_file_download_exceptions)

def download_dids(self, items, num_threads=2, trace_custom_fields={}, traces_copy_out=None,
deactivate_file_download_exceptions=False, sort=None):
def download_dids(
self,
items: list[dict[str, Any]],
num_threads: int = 2,
trace_custom_fields: Optional[dict[str, Any]] = None,
traces_copy_out: Optional[list[dict[str, Any]]] = None,
deactivate_file_download_exceptions: bool = False,
sort: Optional[str] = None
) -> list[dict[str, Any]]:
"""
Download items with given DIDs. This function can also download datasets and wildcarded DIDs.
Expand Down Expand Up @@ -286,6 +302,7 @@ def download_dids(self, items, num_threads=2, trace_custom_fields={}, traces_cop
:raises NotAllFilesDownloaded: if not all files could be downloaded
:raises RucioException: if something unexpected went wrong during the download
"""
trace_custom_fields = trace_custom_fields or {}
logger = self.logger
trace_custom_fields['uuid'] = generate_uuid()

Expand All @@ -304,7 +321,15 @@ def download_dids(self, items, num_threads=2, trace_custom_fields={}, traces_cop

return self._check_output(output_items, deactivate_file_download_exceptions=deactivate_file_download_exceptions)

def download_from_metalink_file(self, item, metalink_file_path, num_threads=2, trace_custom_fields={}, traces_copy_out=None, deactivate_file_download_exceptions=False):
def download_from_metalink_file(
self,
item: dict[str, Any],
metalink_file_path: str,
num_threads: int = 2,
trace_custom_fields: Optional[dict[str, Any]] = None,
traces_copy_out: Optional[list[dict[str, Any]]] = None,
deactivate_file_download_exceptions: bool = False
) -> list[dict[str, Any]]:
"""
Download items using a given metalink file.
Expand All @@ -327,6 +352,7 @@ def download_from_metalink_file(self, item, metalink_file_path, num_threads=2, t
:raises NotAllFilesDownloaded: if not all files could be downloaded
:raises RucioException: if something unexpected went wrong during the download
"""
trace_custom_fields = trace_custom_fields or {}
logger = self.logger

logger(logging.INFO, 'Getting sources from metalink file')
Expand All @@ -351,7 +377,13 @@ def download_from_metalink_file(self, item, metalink_file_path, num_threads=2, t

return self._check_output(output_items, deactivate_file_download_exceptions=deactivate_file_download_exceptions)

def _download_multithreaded(self, input_items, num_threads, trace_custom_fields={}, traces_copy_out=None):
def _download_multithreaded(
self,
input_items: list[dict[str, Any]],
num_threads: int,
trace_custom_fields: Optional[dict[str, Any]] = None,
traces_copy_out: Optional[list[dict[str, Any]]] = None
) -> list[dict[str, Any]]:
"""
Starts an appropriate number of threads to download items from the input list.
(This function is meant to be used as class internal only)
Expand All @@ -363,6 +395,7 @@ def _download_multithreaded(self, input_items, num_threads, trace_custom_fields=
:returns: list with output items as dictionaries
"""
trace_custom_fields = trace_custom_fields or {}
logger = self.logger

num_files = len(input_items)
Expand Down Expand Up @@ -730,7 +763,14 @@ def _download_item(self, item, trace, traces_copy_out, log_prefix=''):

return item

def download_aria2c(self, items, trace_custom_fields={}, filters={}, deactivate_file_download_exceptions=False, sort=None):
def download_aria2c(
self,
items: list[dict[str, Any]],
trace_custom_fields: Optional[dict[str, Any]] = None,
filters: Optional[dict[str, Any]] = None,
deactivate_file_download_exceptions: bool = False,
sort: Optional[str] = None
) -> list[dict[str, Any]]:
"""
Uses aria2c to download the items with given DIDs. This function can also download datasets and wildcarded DIDs.
It only can download files that are available via https/davs.
Expand Down Expand Up @@ -760,6 +800,8 @@ def download_aria2c(self, items, trace_custom_fields={}, filters={}, deactivate_
:raises NotAllFilesDownloaded: if not all files could be downloaded
:raises RucioException: if something went wrong during the download (e.g. aria2c could not be started)
"""
trace_custom_fields = trace_custom_fields or {}
filters = filters or {}
logger = self.logger
trace_custom_fields['uuid'] = generate_uuid()

Expand Down Expand Up @@ -860,7 +902,13 @@ def _start_aria2c_rpc(self, rpc_secret):
raise RucioException('Failed to initialise rpc proxy!', error)
return (rpcproc, aria_rpc)

def _download_items_aria2c(self, items, aria_rpc, rpc_auth, trace_custom_fields={}):
def _download_items_aria2c(
self,
items: list[dict[str, Any]],
aria_rpc: Any,
rpc_auth: str,
trace_custom_fields: Optional[dict[str, Any]] = None
) -> list[dict[str, Any]]:
"""
Uses aria2c to download the given items. Aria2c needs to be started
as RPC background process first and a RPC proxy is needed.
Expand All @@ -873,6 +921,7 @@ def _download_items_aria2c(self, items, aria_rpc, rpc_auth, trace_custom_fields=
:returns: a list of dictionaries with an entry for each file, containing the input options, the did, and the clientState
"""
trace_custom_fields = trace_custom_fields or {}
logger = self.logger

gid_to_item = {} # maps an aria2c download id (gid) to the download item
Expand Down

0 comments on commit 1b631a7

Please sign in to comment.