Skip to content

Commit

Permalink
Typing: Add type hints; #6454
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Mar 7, 2024
1 parent e966cf8 commit 9870978
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 104 deletions.
4 changes: 2 additions & 2 deletions lib/rucio/common/utils.py
Expand Up @@ -40,7 +40,7 @@
from functools import partial, wraps
from io import StringIO
from itertools import zip_longest
from typing import TYPE_CHECKING, Callable, Optional
from typing import Any, TYPE_CHECKING, Callable, Optional
from urllib.parse import urlparse, urlencode, quote, parse_qsl, urlunparse
from uuid import uuid4 as uuid
from xml.etree import ElementTree
Expand Down Expand Up @@ -568,7 +568,7 @@ def str_to_date(string):
return datetime.datetime.strptime(string, DATE_FORMAT) if string else None


def val_to_space_sep_str(vallist):
def val_to_space_sep_str(vallist: Any) -> str:
""" Converts a list of values into a string of space separated values
:param vallist: the list of values to to convert into string
Expand Down
58 changes: 29 additions & 29 deletions lib/rucio/core/account.py
Expand Up @@ -17,7 +17,7 @@
from enum import Enum
from re import match
from traceback import format_exc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast, Iterator, Optional

from sqlalchemy import select, and_
from sqlalchemy.exc import IntegrityError
Expand All @@ -27,6 +27,7 @@
import rucio.core.rse
from rucio.common import exception
from rucio.common.config import config_get_bool
from rucio.common.types import InternalAccount, AccountAttributesDict, AccountDict, AccountUsageModelDict, IdentityDict, UsageDict
from rucio.core.vo import vo_exists
from rucio.db.sqla import models
from rucio.db.sqla.constants import AccountStatus, AccountType
Expand All @@ -37,7 +38,7 @@


@transactional_session
def add_account(account, type_, email, *, session: "Session"):
def add_account(account: InternalAccount, type_: AccountType, email: str, *, session: "Session") -> None:
""" Add an account with the given account name and type.
:param account: the name of the new account.
Expand All @@ -63,7 +64,7 @@ def add_account(account, type_, email, *, session: "Session"):


@read_session
def account_exists(account, *, session: "Session"):
def account_exists(account: InternalAccount, *, session: "Session") -> bool:
""" Checks to see if account exists and is active.
:param account: Name of the account.
Expand All @@ -81,7 +82,7 @@ def account_exists(account, *, session: "Session"):


@read_session
def get_account(account, *, session: "Session"):
def get_account(account: InternalAccount, *, session: "Session") -> models.Account:
""" Returns an account for the given account name.
:param account: the name of the account.
Expand All @@ -102,7 +103,7 @@ def get_account(account, *, session: "Session"):


@transactional_session
def del_account(account, *, session: "Session"):
def del_account(account: InternalAccount, *, session: "Session") -> None:
""" Disable an account with the given account name.
:param account: the account name.
Expand All @@ -115,15 +116,15 @@ def del_account(account, *, session: "Session"):
models.Account.status == AccountStatus.ACTIVE
)
try:
account = session.execute(query).scalar_one()
account_result = session.execute(query).scalar_one()
except exc.NoResultFound:
raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account)

account.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()})
account_result.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()})


@transactional_session
def update_account(account, key, value, *, session: "Session"):
def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None:
""" Update a property of an account.
:param account: Name of the account.
Expand All @@ -137,22 +138,22 @@ def update_account(account, key, value, *, session: "Session"):
models.Account.account == account
)
try:
account = session.execute(query).scalar_one()
query_result = session.execute(query).scalar_one()
except exc.NoResultFound:
raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account)
if key == 'status':
if isinstance(value, str):
value = AccountStatus[value]
if value == AccountStatus.SUSPENDED:
account.update({'status': value, 'suspended_at': datetime.utcnow()})
query_result.update({'status': value, 'suspended_at': datetime.utcnow()})
elif value == AccountStatus.ACTIVE:
account.update({'status': value, 'suspended_at': None})
query_result.update({'status': value, 'suspended_at': None})
else:
account.update({key: value})
query_result.update({key: value})


@stream_session
def list_accounts(filter_=None, *, session: "Session"):
def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[AccountDict]:
""" Returns a list of all account names.
:param filter_: Dictionary of attributes by which the input data should be filtered
Expand Down Expand Up @@ -213,7 +214,7 @@ def list_accounts(filter_=None, *, session: "Session"):


@read_session
def list_identities(account, *, session: "Session"):
def list_identities(account: InternalAccount, *, session: "Session") -> list[IdentityDict]:
"""
List all identities on an account.
Expand Down Expand Up @@ -244,11 +245,11 @@ def list_identities(account, *, session: "Session"):
).where(
models.IdentityAccountAssociation.account == account
)
return [row._asdict() for row in session.execute(query)]
return [cast(IdentityDict, row._asdict()) for row in session.execute(query)]


@read_session
def list_account_attributes(account, *, session: "Session"):
def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[AccountAttributesDict]:
"""
Get all attributes defined for an account.
Expand All @@ -274,11 +275,11 @@ def list_account_attributes(account, *, session: "Session"):
).where(
models.AccountAttrAssociation.account == account
)
return [row._asdict() for row in session.execute(query)]
return [cast(AccountAttributesDict, row._asdict()) for row in session.execute(query)]


@read_session
def has_account_attribute(account, key, *, session: "Session"):
def has_account_attribute(account: InternalAccount, key: str, *, session: "Session") -> bool:
"""
Indicates whether the named key is present for the account.
Expand All @@ -298,7 +299,7 @@ def has_account_attribute(account, key, *, session: "Session"):


@transactional_session
def add_account_attribute(account, key, value, *, session: "Session"):
def add_account_attribute(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None:
"""
Add an attribute for the given account name.
Expand Down Expand Up @@ -334,7 +335,7 @@ def add_account_attribute(account, key, value, *, session: "Session"):


@transactional_session
def del_account_attribute(account, key, *, session: "Session"):
def del_account_attribute(account: InternalAccount, key: str, *, session: "Session") -> None:
"""
Add an attribute for the given account name.
Expand All @@ -355,7 +356,7 @@ def del_account_attribute(account, key, *, session: "Session"):


@read_session
def get_usage(rse_id, account, *, session: "Session"):
def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> UsageDict:
"""
Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist.
Expand All @@ -373,13 +374,13 @@ def get_usage(rse_id, account, *, session: "Session"):
models.AccountUsage.account == account
)
try:
return session.execute(query).one()._asdict()
return cast(UsageDict, session.execute(query).one()._asdict())
except exc.NoResultFound:
return {'bytes': 0, 'files': 0, 'updated_at': None}
return UsageDict({'bytes': 0, 'files': 0, 'updated_at': None})


@read_session
def get_all_rse_usages_per_account(account, *, session: "Session"):
def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[AccountUsageModelDict]:
"""
Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist.
Expand All @@ -394,20 +395,20 @@ def get_all_rse_usages_per_account(account, *, session: "Session"):
models.AccountUsage.account == account
)
try:
return [result.to_dict() for result in session.execute(query).scalars()]
return [cast(AccountUsageModelDict, result.to_dict()) for result in session.execute(query).scalars()]
except exc.NoResultFound:
return []


@read_session
def get_usage_history(rse_id, account, *, session: "Session"):
def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[UsageDict]:
"""
Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist.
:param rse_id: The id of the RSE.
:param account: The account name.
:param session: The database session in use.
:returns: A dictionary {'bytes', 'files', 'updated_at'}
:returns: A list of dictionaries {'bytes', 'files', 'updated_at'}
"""
query = select(
models.AccountUsageHistory.bytes,
Expand All @@ -420,7 +421,6 @@ def get_usage_history(rse_id, account, *, session: "Session"):
models.AccountUsageHistory.updated_at
)
try:
return [row._asdict() for row in session.execute(query)]
return [cast(UsageDict, row._asdict()) for row in session.execute(query)]
except exc.NoResultFound:
raise exception.CounterNotFound('No usage can be found for account %s on RSE %s' % (account, rucio.core.rse.get_rse_name(rse_id=rse_id, session=session)))
return []
20 changes: 11 additions & 9 deletions lib/rucio/core/account_counter.py
Expand Up @@ -12,12 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from typing import TYPE_CHECKING
from typing import cast, TYPE_CHECKING

from sqlalchemy import literal, insert, select
from sqlalchemy.orm.exc import NoResultFound

from rucio.common.types import InternalAccount, RSEAccountCounterDict
from rucio.db.sqla import models, filter_thread_work
from rucio.db.sqla.session import read_session, transactional_session

Expand All @@ -28,7 +30,7 @@


@transactional_session
def add_counter(rse_id, account, *, session: "Session"):
def add_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> None:
"""
Creates the specified counter for a rse_id and account.
Expand All @@ -41,7 +43,7 @@ def add_counter(rse_id, account, *, session: "Session"):


@transactional_session
def increase(rse_id, account, files, bytes_, *, session: "Session"):
def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None:
"""
Increments the specified counter by the specified amount.
Expand All @@ -55,7 +57,7 @@ def increase(rse_id, account, files, bytes_, *, session: "Session"):


@transactional_session
def decrease(rse_id, account, files, bytes_, *, session: "Session"):
def decrease(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None:
"""
Decreases the specified counter by the specified amount.
Expand All @@ -69,7 +71,7 @@ def decrease(rse_id, account, files, bytes_, *, session: "Session"):


@transactional_session
def del_counter(rse_id, account, *, session: "Session"):
def del_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> None:
"""
Resets the specified counter and initializes it by the specified amounts.
Expand All @@ -82,7 +84,7 @@ def del_counter(rse_id, account, *, session: "Session"):


@read_session
def get_updated_account_counters(total_workers, worker_number, *, session: "Session"):
def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[RSEAccountCounterDict]:
"""
Get updated rse_counters.
Expand All @@ -104,7 +106,7 @@ def get_updated_account_counters(total_workers, worker_number, *, session: "Sess


@transactional_session
def update_account_counter(account, rse_id, *, session: "Session"):
def update_account_counter(account: InternalAccount, rse_id: str, *, session: "Session") -> None:
"""
Read the updated_account_counters and update the account_counter.
Expand All @@ -130,7 +132,7 @@ def update_account_counter(account, rse_id, *, session: "Session"):


@transactional_session
def update_account_counter_history(account, rse_id, *, session: "Session"):
def update_account_counter_history(account: InternalAccount, rse_id: str, *, session: "Session") -> None:
"""
Read the AccountUsage and update the AccountUsageHistory.
Expand All @@ -146,7 +148,7 @@ def update_account_counter_history(account, rse_id, *, session: "Session"):


@transactional_session
def fill_account_counter_history_table(*, session: "Session"):
def fill_account_counter_history_table(*, session: "Session") -> None:
"""
Make a snapshot of current counters
Expand Down

0 comments on commit 9870978

Please sign in to comment.