Skip to content

Commit

Permalink
Auth: Refactor to use custom dict types and fix type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Feb 28, 2024
1 parent 3da811e commit 5ce0642
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 73 deletions.
91 changes: 91 additions & 0 deletions lib/rucio/common/types.py
Expand Up @@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from uuid import UUID
from datetime import datetime
from typing import Any, Callable, Optional, TypedDict, Union

from rucio.db.sqla.constants import AccountType, IdentityType


class InternalType(object):
'''
Expand Down Expand Up @@ -162,6 +165,94 @@ class RSESettingsDict(TypedDict):
protocols: list[RSEProtocolDict]


class RSEAccountCounterDict(TypedDict):
account: InternalAccount
rse_id: UUID


class RSEAccountUsageDict(TypedDict):
rse_id: UUID
rse: str
account: InternalAccount
used_files: int
used_bytes: int
quota_bytes: int


class RSEGlobalAccountUsageDict(TypedDict):
rse_expression: str
bytes: int
files: int
bytes_limit: int
bytes_remaining: int


class RSELocalAccountUsageDict(TypedDict):
rse_id: UUID
rse: str
bytes: int
files: int
bytes_limit: int
bytes_remaining: int


class RSEResolvedGlobalAccountLimitDict(TypedDict):
resolved_rses: str
resolved_rse_ids: list[UUID]
limit: float


class TokenDict(TypedDict):
token: str
expires_at: datetime


class TokenValidationDict(TypedDict):
account: Optional[InternalAccount]
identity: Optional[str]
lifetime: Optional[datetime]
audience: Optional[str]
authz_scope: Optional[str]


class AccountDict(TypedDict):
account: InternalAccount
type: AccountType
email: str


class AccountAttributesDict(TypedDict):
key: str
value: Union[bool, str]


class IdentityDict(TypedDict):
type: IdentityType
identity: str
email: str


class UsageDict(TypedDict):
bytes: int
files: int
updated_at: Optional[datetime]


class TokenOIDCAutoDict(TypedDict, total=False):
webhome: str
token: TokenDict


class TokenOIDCNoAutoDict(TypedDict):
fetchcode: str


class TokenOIDCPollingDict(TypedDict):
polling: bool


class AccountUsageModelDict(TypedDict):
account: InternalAccount
rse_id: UUID
files: int
bytes: int
4 changes: 2 additions & 2 deletions lib/rucio/common/utils.py
Expand Up @@ -40,7 +40,7 @@
from functools import partial, wraps
from io import StringIO
from itertools import zip_longest
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union
from urllib.parse import urlparse, urlencode, quote, parse_qsl, urlunparse
from uuid import uuid4 as uuid
from xml.etree import ElementTree
Expand Down Expand Up @@ -1200,7 +1200,7 @@ def detect_client_location():
'longitude': longitude}


def ssh_sign(private_key: str, message: str) -> str:
def ssh_sign(private_key: str, message: Union[str, bytes]) -> str:
"""
Sign a string message using the private key.
Expand Down
48 changes: 24 additions & 24 deletions lib/rucio/core/account.py
Expand Up @@ -17,7 +17,7 @@
from enum import Enum
from re import match
from traceback import format_exc
from typing import TYPE_CHECKING, Any, Iterator, Optional
from typing import TYPE_CHECKING, Any, cast, Iterator, Optional

from sqlalchemy import select, and_
from sqlalchemy.exc import IntegrityError
Expand All @@ -27,7 +27,7 @@
import rucio.core.rse
from rucio.common import exception
from rucio.common.config import config_get_bool
from rucio.common.types import InternalAccount
from rucio.common.types import InternalAccount, AccountAttributesDict, AccountDict, AccountUsageModelDict, IdentityDict, UsageDict
from rucio.core.vo import vo_exists
from rucio.db.sqla import models
from rucio.db.sqla.constants import AccountStatus, AccountType
Expand Down Expand Up @@ -138,22 +138,22 @@ def update_account(account: InternalAccount, key: str, value: Any, *, session: "
models.Account.account == account
)
try:
account_result = session.execute(query).scalar_one()
if key == 'status':
if isinstance(value, str):
value = AccountStatus[value]
if value == AccountStatus.SUSPENDED:
account_result.update({'status': value, 'suspended_at': datetime.utcnow()})
elif value == AccountStatus.ACTIVE:
account_result.update({'status': value, 'suspended_at': None})
else:
account_result.update({key: value})
query_result = session.execute(query).scalar_one()
except exc.NoResultFound:
raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account)
if key == 'status':
if isinstance(value, str):
value = AccountStatus[value]
if value == AccountStatus.SUSPENDED:
query_result.update({'status': value, 'suspended_at': datetime.utcnow()})
elif value == AccountStatus.ACTIVE:
query_result.update({'status': value, 'suspended_at': None})
else:
query_result.update({key: value})


@stream_session
def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[dict]:
def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[AccountDict]:
""" Returns a list of all account names.
:param filter_: Dictionary of attributes by which the input data should be filtered
Expand Down Expand Up @@ -214,7 +214,7 @@ def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session


@read_session
def list_identities(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]:
def list_identities(account: InternalAccount, *, session: "Session") -> list[IdentityDict]:
"""
List all identities on an account.
Expand Down Expand Up @@ -245,11 +245,11 @@ def list_identities(account: InternalAccount, *, session: "Session") -> list[dic
).where(
models.IdentityAccountAssociation.account == account
)
return [row._asdict() for row in session.execute(query)]
return [cast(IdentityDict, row._asdict()) for row in session.execute(query)]


@read_session
def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]:
def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[AccountAttributesDict]:
"""
Get all attributes defined for an account.
Expand All @@ -275,7 +275,7 @@ def list_account_attributes(account: InternalAccount, *, session: "Session") ->
).where(
models.AccountAttrAssociation.account == account
)
return [row._asdict() for row in session.execute(query)]
return [cast(AccountAttributesDict, row._asdict()) for row in session.execute(query)]


@read_session
Expand Down Expand Up @@ -356,7 +356,7 @@ def del_account_attribute(account: InternalAccount, key: str, *, session: "Sessi


@read_session
def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> dict:
def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> UsageDict:
"""
Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist.
Expand All @@ -374,13 +374,13 @@ def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> d
models.AccountUsage.account == account
)
try:
return session.execute(query).one()._asdict()
return cast(UsageDict, session.execute(query).one()._asdict())
except exc.NoResultFound:
return {'bytes': 0, 'files': 0, 'updated_at': None}
return UsageDict({'bytes': 0, 'files': 0, 'updated_at': None})


@read_session
def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[dict]:
def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[AccountUsageModelDict]:
"""
Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist.
Expand All @@ -395,13 +395,13 @@ def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Sessio
models.AccountUsage.account == account
)
try:
return [result.to_dict() for result in session.execute(query).scalars()]
return [cast(AccountUsageModelDict, result.to_dict()) for result in session.execute(query).scalars()]
except exc.NoResultFound:
return []


@read_session
def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[dict]:
def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[UsageDict]:
"""
Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist.
Expand All @@ -421,6 +421,6 @@ def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Sessio
models.AccountUsageHistory.updated_at
)
try:
return [row._asdict() for row in session.execute(query)]
return [cast(UsageDict, row._asdict()) for row in session.execute(query)]
except exc.NoResultFound:
raise exception.CounterNotFound('No usage can be found for account %s on RSE %s' % (account, rucio.core.rse.get_rse_name(rse_id=rse_id, session=session)))
14 changes: 7 additions & 7 deletions lib/rucio/core/account_counter.py
Expand Up @@ -14,12 +14,12 @@
# limitations under the License.

import datetime
from typing import TYPE_CHECKING, Any
from typing import cast, TYPE_CHECKING

from sqlalchemy import literal, insert, select
from sqlalchemy.orm.exc import NoResultFound

from rucio.common.types import InternalAccount
from rucio.common.types import InternalAccount, RSEAccountCounterDict
from rucio.db.sqla import models, filter_thread_work
from rucio.db.sqla.session import read_session, transactional_session

Expand All @@ -30,7 +30,7 @@


@transactional_session
def add_counter(rse_id: str, account: InternalAccount, *, session: "Session"):
def add_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> None:
"""
Creates the specified counter for a rse_id and account.
Expand All @@ -43,7 +43,7 @@ def add_counter(rse_id: str, account: InternalAccount, *, session: "Session"):


@transactional_session
def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session"):
def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None:
"""
Increments the specified counter by the specified amount.
Expand Down Expand Up @@ -84,7 +84,7 @@ def del_counter(rse_id: str, account: InternalAccount, *, session: "Session") ->


@read_session
def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[dict[str, Any]]:
def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[RSEAccountCounterDict]:
"""
Get updated rse_counters.
Expand All @@ -102,7 +102,7 @@ def get_updated_account_counters(total_workers: int, worker_number: int, *, sess

query = filter_thread_work(session=session, query=query, total_threads=total_workers, thread_id=worker_number, hash_variable='CONCAT(account, rse_id)')

return [row._asdict() for row in session.execute(query).all()]
return [cast(RSEAccountCounterDict, row._asdict()) for row in session.execute(query).all()]


@transactional_session
Expand Down Expand Up @@ -148,7 +148,7 @@ def update_account_counter_history(account: InternalAccount, rse_id: str, *, ses


@transactional_session
def fill_account_counter_history_table(*, session: "Session"):
def fill_account_counter_history_table(*, session: "Session") -> None:
"""
Make a snapshot of current counters
Expand Down
16 changes: 8 additions & 8 deletions lib/rucio/core/account_limit.py
Expand Up @@ -20,7 +20,7 @@
from sqlalchemy.sql import func, select, literal
from sqlalchemy.sql.expression import and_, or_

from rucio.common.types import InternalAccount
from rucio.common.types import InternalAccount, RSEAccountUsageDict, RSEGlobalAccountUsageDict, RSELocalAccountUsageDict, RSEResolvedGlobalAccountLimitDict
from rucio.core.account import get_all_rse_usages_per_account
from rucio.core.rse import get_rse_name
from rucio.core.rse_expression_parser import parse_expression
Expand All @@ -32,7 +32,7 @@


@read_session
def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[dict]:
def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[RSEAccountUsageDict]:
"""
Returns the account limit and usage for all accounts on a RSE.
Expand Down Expand Up @@ -92,7 +92,7 @@ def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[dict]:


@read_session
def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict:
def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict[str, RSEResolvedGlobalAccountLimitDict]:
"""
Returns the global account limits for the account.
Expand Down Expand Up @@ -164,7 +164,7 @@ def get_local_account_limit(account: InternalAccount, rse_id: str, *, session: "


@read_session
def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[str]] = None, *, session: "Session") -> dict:
def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[str]] = None, *, session: "Session") -> dict[uuid.UUID, int]:
"""
Returns the account limits for the account on the list of rses.
Expand Down Expand Up @@ -198,7 +198,7 @@ def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[st


@transactional_session
def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, *, session: "Session"):
def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, *, session: "Session") -> None:
"""
Returns the limits for the account on the rse.
Expand All @@ -216,7 +216,7 @@ def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int,


@transactional_session
def set_global_account_limit(account: InternalAccount, rse_expression: str, bytes_: int, *, session: "Session"):
def set_global_account_limit(account: InternalAccount, rse_expression: str, bytes_: int, *, session: "Session") -> None:
"""
Sets the global limit for the account on a RSE expression.
Expand Down Expand Up @@ -270,7 +270,7 @@ def delete_global_account_limit(account: InternalAccount, rse_expression: str, *


@transactional_session
def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[dict]:
def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[RSELocalAccountUsageDict]:
"""
Read the account usage and connect it with (if available) the account limits of the account.
Expand Down Expand Up @@ -313,7 +313,7 @@ def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID


@transactional_session
def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[dict]:
def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[RSEGlobalAccountUsageDict]:
"""
Read the account usage and connect it with the global account limits of the account.
Expand Down

0 comments on commit 5ce0642

Please sign in to comment.