Skip to content

Commit

Permalink
Auth: Refactor to use TokenDict type and add type hints; #6454
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Feb 19, 2024
1 parent 03cd7c6 commit e332726
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 89 deletions.
6 changes: 6 additions & 0 deletions lib/rucio/common/types.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from typing import Any, Callable, Optional, TypedDict, Union


Expand Down Expand Up @@ -159,3 +160,8 @@ class RSESettingsDict(TypedDict):
deterministic: bool
domain: list[str]
protocols: list[RSEProtocolDict]


class TokenDict(TypedDict):
token: str
expires_at: datetime
4 changes: 2 additions & 2 deletions lib/rucio/common/utils.py
Expand Up @@ -40,7 +40,7 @@
from functools import partial, wraps
from io import StringIO
from itertools import zip_longest
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse, urlencode, quote, parse_qsl, urlunparse
from uuid import uuid4 as uuid
from xml.etree import ElementTree
Expand Down Expand Up @@ -570,7 +570,7 @@ def str_to_date(string):
return datetime.datetime.strptime(string, DATE_FORMAT) if string else None


def val_to_space_sep_str(vallist):
def val_to_space_sep_str(vallist: Any) -> str:
""" Converts a list of values into a string of space separated values
:param vallist: the list of values to to convert into string
Expand Down
52 changes: 25 additions & 27 deletions lib/rucio/core/account.py
Expand Up @@ -17,9 +17,7 @@
from enum import Enum
from re import match
from traceback import format_exc
from typing import TYPE_CHECKING, Any
from collections.abc import Generator
import uuid
from typing import TYPE_CHECKING, Any, Iterator, Optional

from sqlalchemy import select, and_
from sqlalchemy.exc import IntegrityError
Expand All @@ -40,7 +38,7 @@


@transactional_session
def add_account(account: InternalAccount, type_: AccountType, email: str, *, session: "Session"):
def add_account(account: InternalAccount, type_: AccountType, email: str, *, session: "Session") -> None:
""" Add an account with the given account name and type.
:param account: the name of the new account.
Expand Down Expand Up @@ -84,7 +82,7 @@ def account_exists(account: InternalAccount, *, session: "Session") -> bool:


@read_session
def get_account(account: InternalAccount, *, session: "Session") -> dict:
def get_account(account: InternalAccount, *, session: "Session") -> models.Account:
""" Returns an account for the given account name.
:param account: the name of the account.
Expand All @@ -105,7 +103,7 @@ def get_account(account: InternalAccount, *, session: "Session") -> dict:


@transactional_session
def del_account(account: InternalAccount, *, session: "Session"):
def del_account(account: InternalAccount, *, session: "Session") -> None:
""" Disable an account with the given account name.
:param account: the account name.
Expand All @@ -118,15 +116,15 @@ def del_account(account: InternalAccount, *, session: "Session"):
models.Account.status == AccountStatus.ACTIVE
)
try:
account = session.execute(query).scalar_one()
account_result = session.execute(query).scalar_one()
except exc.NoResultFound:
raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account)

account.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()})
account_result.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()})


@transactional_session
def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session"):
def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None:
""" Update a property of an account.
:param account: Name of the account.
Expand All @@ -140,22 +138,22 @@ def update_account(account: InternalAccount, key: str, value: Any, *, session: "
models.Account.account == account
)
try:
account = session.execute(query).scalar_one()
account_result = session.execute(query).scalar_one()
if key == 'status':
if isinstance(value, str):
value = AccountStatus[value]
if value == AccountStatus.SUSPENDED:
account_result.update({'status': value, 'suspended_at': datetime.utcnow()})
elif value == AccountStatus.ACTIVE:
account_result.update({'status': value, 'suspended_at': None})
else:
account_result.update({key: value})
except exc.NoResultFound:
raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account)
if key == 'status':
if isinstance(value, str):
value = AccountStatus[value]
if value == AccountStatus.SUSPENDED:
account.update({'status': value, 'suspended_at': datetime.utcnow()})
elif value == AccountStatus.ACTIVE:
account.update({'status': value, 'suspended_at': None})
else:
account.update({key: value})


@stream_session
def list_accounts(filter_: dict = None, *, session: "Session") -> Generator[dict]:
def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[dict]:
""" Returns a list of all account names.
:param filter_: Dictionary of attributes by which the input data should be filtered
Expand Down Expand Up @@ -216,7 +214,7 @@ def list_accounts(filter_: dict = None, *, session: "Session") -> Generator[dict


@read_session
def list_identities(account: InternalAccount, *, session: "Session"):
def list_identities(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]:
"""
List all identities on an account.
Expand Down Expand Up @@ -251,7 +249,7 @@ def list_identities(account: InternalAccount, *, session: "Session"):


@read_session
def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[dict]:
def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]:
"""
Get all attributes defined for an account.
Expand Down Expand Up @@ -301,7 +299,7 @@ def has_account_attribute(account: InternalAccount, key: str, *, session: "Sessi


@transactional_session
def add_account_attribute(account: InternalAccount, key: str, value: Any, *, session: "Session"):
def add_account_attribute(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None:
"""
Add an attribute for the given account name.
Expand Down Expand Up @@ -337,7 +335,7 @@ def add_account_attribute(account: InternalAccount, key: str, value: Any, *, ses


@transactional_session
def del_account_attribute(account: InternalAccount, key: str, *, session: "Session"):
def del_account_attribute(account: InternalAccount, key: str, *, session: "Session") -> None:
"""
Add an attribute for the given account name.
Expand All @@ -358,7 +356,7 @@ def del_account_attribute(account: InternalAccount, key: str, *, session: "Sessi


@read_session
def get_usage(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> dict:
def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> dict:
"""
Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist.
Expand Down Expand Up @@ -403,14 +401,14 @@ def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Sessio


@read_session
def get_usage_history(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> dict:
def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[dict]:
"""
Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist.
:param rse_id: The id of the RSE.
:param account: The account name.
:param session: The database session in use.
:returns: A dictionary {'bytes', 'files', 'updated_at'}
:returns: A list of dictionaries {'bytes', 'files', 'updated_at'}
"""
query = select(
models.AccountUsageHistory.bytes,
Expand Down
17 changes: 8 additions & 9 deletions lib/rucio/core/account_counter.py
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
import datetime
from typing import TYPE_CHECKING

Expand All @@ -25,13 +24,13 @@
from rucio.db.sqla.session import read_session, transactional_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session, Query
from sqlalchemy.orm import Session

MAX_COUNTERS = 10


@transactional_session
def add_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session"):
def add_counter(rse_id: str, account: InternalAccount, *, session: "Session"):
"""
Creates the specified counter for a rse_id and account.
Expand All @@ -44,7 +43,7 @@ def add_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Sessio


@transactional_session
def increase(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: int, *, session: "Session"):
def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session"):
"""
Increments the specified counter by the specified amount.
Expand All @@ -58,7 +57,7 @@ def increase(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: in


@transactional_session
def decrease(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None:
def decrease(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None:
"""
Decreases the specified counter by the specified amount.
Expand All @@ -72,7 +71,7 @@ def decrease(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: in


@transactional_session
def del_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session"):
def del_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> None:
"""
Resets the specified counter and initializes it by the specified amounts.
Expand All @@ -85,7 +84,7 @@ def del_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Sessio


@read_session
def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list["Query"]:
def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[tuple[InternalAccount, str]]:
"""
Get updated rse_counters.
Expand All @@ -108,7 +107,7 @@ def get_updated_account_counters(total_workers: int, worker_number: int, *, sess


@transactional_session
def update_account_counter(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session"):
def update_account_counter(account: InternalAccount, rse_id: str, *, session: "Session") -> None:
"""
Read the updated_account_counters and update the account_counter.
Expand All @@ -134,7 +133,7 @@ def update_account_counter(account: InternalAccount, rse_id: uuid.UUID, *, sessi


@transactional_session
def update_account_counter_history(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session"):
def update_account_counter_history(account: InternalAccount, rse_id: str, *, session: "Session") -> None:
"""
Read the AccountUsage and update the AccountUsageHistory.
Expand Down
20 changes: 10 additions & 10 deletions lib/rucio/core/account_limit.py
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import uuid
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql import func, select, literal
Expand All @@ -32,7 +32,7 @@


@read_session
def get_rse_account_usage(rse_id: uuid.UUID, *, session: "Session") -> list[dict]:
def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[dict]:
"""
Returns the account limit and usage for all accounts on a RSE.
Expand Down Expand Up @@ -92,7 +92,7 @@ def get_rse_account_usage(rse_id: uuid.UUID, *, session: "Session") -> list[dict


@read_session
def get_global_account_limits(account: InternalAccount = None, *, session: "Session") -> dict:
def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict:
"""
Returns the global account limits for the account.
Expand Down Expand Up @@ -123,7 +123,7 @@ def get_global_account_limits(account: InternalAccount = None, *, session: "Sess


@read_session
def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> int:
def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> Union[int, float, None]:
"""
Returns the global account limit for the account on the rse expression.
Expand All @@ -143,7 +143,7 @@ def get_global_account_limit(account: InternalAccount, rse_expression: str, *, s


@read_session
def get_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session") -> Optional[int | float]:
def get_local_account_limit(account: InternalAccount, rse_id: str, *, session: "Session") -> Union[int, float, None]:
"""
Returns the account limit for the account on the rse.
Expand All @@ -164,7 +164,7 @@ def get_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, sess


@read_session
def get_local_account_limits(account: InternalAccount, rse_ids: list[uuid.UUID] = None, *, session: "Session") -> dict:
def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[str]] = None, *, session: "Session") -> dict:
"""
Returns the account limits for the account on the list of rses.
Expand Down Expand Up @@ -198,7 +198,7 @@ def get_local_account_limits(account: InternalAccount, rse_ids: list[uuid.UUID]


@transactional_session
def set_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, bytes_: int, *, session: "Session"):
def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, *, session: "Session"):
"""
Returns the limits for the account on the rse.
Expand Down Expand Up @@ -234,7 +234,7 @@ def set_global_account_limit(account: InternalAccount, rse_expression: str, byte


@transactional_session
def delete_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session") -> bool:
def delete_local_account_limit(account: InternalAccount, rse_id: str, *, session: "Session") -> bool:
"""
Deletes a local account limit.
Expand Down Expand Up @@ -270,7 +270,7 @@ def delete_global_account_limit(account: InternalAccount, rse_expression: str, *


@transactional_session
def get_local_account_usage(account: InternalAccount, rse_id: uuid.UUID = None, *, session: "Session") -> list[dict]:
def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[dict]:
"""
Read the account usage and connect it with (if available) the account limits of the account.
Expand Down Expand Up @@ -313,7 +313,7 @@ def get_local_account_usage(account: InternalAccount, rse_id: uuid.UUID = None,


@transactional_session
def get_global_account_usage(account: InternalAccount, rse_expression: str = None, *, session: "Session") -> list[dict]:
def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[dict]:
"""
Read the account usage and connect it with the global account limits of the account.
Expand Down

0 comments on commit e332726

Please sign in to comment.