diff --git a/lib/rucio/common/types.py b/lib/rucio/common/types.py index 849d89f6c0..4154814562 100644 --- a/lib/rucio/common/types.py +++ b/lib/rucio/common/types.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from uuid import UUID from datetime import datetime from typing import Any, Callable, Optional, TypedDict, Union +from rucio.db.sqla.constants import AccountType, IdentityType + class InternalType(object): ''' @@ -162,6 +165,87 @@ class RSESettingsDict(TypedDict): protocols: list[RSEProtocolDict] +class RSEAccountCounterDict(TypedDict): + account: InternalAccount + rse_id: UUID + + +class RSEAccountUsageDict(TypedDict): + rse_id: UUID + rse: str + account: InternalAccount + used_files: int + used_bytes: int + quota_bytes: int + + +class RSEGlobalAccountUsageDict(TypedDict): + rse_expression: str + bytes: int + files: int + bytes_limit: int + bytes_remaining: int + + +class RSELocalAccountUsageDict(TypedDict): + rse_id: UUID + rse: str + bytes: int + files: int + bytes_limit: int + bytes_remaining: int + + +class RSEResolvedGlobalAccountLimitDict(TypedDict): + resolved_rses: str + resolved_rse_ids: list[UUID] + limit: float + + class TokenDict(TypedDict): token: str expires_at: datetime + + +class TokenValidationDict(TypedDict): + account: InternalAccount + identity: str + lifetime: datetime + audience: str + authz_scope: str + + +class AccountDict(TypedDict): + account: InternalAccount + type: AccountType + email: str + + +class AccountAttributesDict(TypedDict): + key: str + value: Union[bool, str] + + +class IdentityDict(TypedDict): + type: IdentityType + identity: str + email: str + + +class UsageDict(TypedDict): + bytes: int + files: int + updated_at: Optional[datetime] + + +class TokenOIDCAutoDict(TypedDict): + webhome: Optional[str] + token: Optional[TokenDict] + + +class TokenOIDCNoAutoDict(TypedDict): + fetchcode: str + + +class TokenOIDCPollingDict(TypedDict): + polling: bool diff --git a/lib/rucio/core/account.py b/lib/rucio/core/account.py index 1015b5d016..349e9c9412 100644 --- a/lib/rucio/core/account.py +++ b/lib/rucio/core/account.py @@ -27,7 +27,7 @@ import rucio.core.rse from rucio.common import exception from rucio.common.config import config_get_bool -from rucio.common.types import InternalAccount +from rucio.common.types import InternalAccount, AccountAttributesDict, AccountDict, IdentityDict, UsageDict from rucio.core.vo import vo_exists from rucio.db.sqla import models from rucio.db.sqla.constants import AccountStatus, AccountType @@ -138,22 +138,22 @@ def update_account(account: InternalAccount, key: str, value: Any, *, session: " models.Account.account == account ) try: - account_result = session.execute(query).scalar_one() - if key == 'status': - if isinstance(value, str): - value = AccountStatus[value] - if value == AccountStatus.SUSPENDED: - account_result.update({'status': value, 'suspended_at': datetime.utcnow()}) - elif value == AccountStatus.ACTIVE: - account_result.update({'status': value, 'suspended_at': None}) - else: - account_result.update({key: value}) + query_result = session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) + if key == 'status': + if isinstance(value, str): + value = AccountStatus[value] + if value == AccountStatus.SUSPENDED: + query_result.update({'status': value, 'suspended_at': datetime.utcnow()}) + elif value == AccountStatus.ACTIVE: + query_result.update({'status': value, 'suspended_at': None}) + else: + query_result.update({key: value}) @stream_session -def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[dict]: +def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[AccountDict]: """ Returns a list of all account names. :param filter_: Dictionary of attributes by which the input data should be filtered @@ -214,7 +214,7 @@ def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session @read_session -def list_identities(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]: +def list_identities(account: InternalAccount, *, session: "Session") -> list[IdentityDict]: """ List all identities on an account. @@ -249,7 +249,7 @@ def list_identities(account: InternalAccount, *, session: "Session") -> list[dic @read_session -def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]: +def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[AccountAttributesDict]: """ Get all attributes defined for an account. @@ -356,7 +356,7 @@ def del_account_attribute(account: InternalAccount, key: str, *, session: "Sessi @read_session -def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> dict: +def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> UsageDict: """ Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist. @@ -380,7 +380,7 @@ def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> d @read_session -def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[dict]: +def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[models.AccountUsage.__dict__]: """ Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist. @@ -401,7 +401,7 @@ def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Sessio @read_session -def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[dict]: +def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[models.AccountUsageHistory.__dict__]: """ Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist. diff --git a/lib/rucio/core/account_counter.py b/lib/rucio/core/account_counter.py index c986cf02af..e1bb3c355d 100644 --- a/lib/rucio/core/account_counter.py +++ b/lib/rucio/core/account_counter.py @@ -14,12 +14,12 @@ # limitations under the License. import datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from sqlalchemy import literal, insert, select from sqlalchemy.orm.exc import NoResultFound -from rucio.common.types import InternalAccount +from rucio.common.types import InternalAccount, RSEAccountCounterDict from rucio.db.sqla import models, filter_thread_work from rucio.db.sqla.session import read_session, transactional_session @@ -30,7 +30,7 @@ @transactional_session -def add_counter(rse_id: str, account: InternalAccount, *, session: "Session"): +def add_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> None: """ Creates the specified counter for a rse_id and account. @@ -43,7 +43,7 @@ def add_counter(rse_id: str, account: InternalAccount, *, session: "Session"): @transactional_session -def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session"): +def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None: """ Increments the specified counter by the specified amount. @@ -84,7 +84,7 @@ def del_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> @read_session -def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[dict[str, Any]]: +def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[RSEAccountCounterDict]: """ Get updated rse_counters. @@ -148,7 +148,7 @@ def update_account_counter_history(account: InternalAccount, rse_id: str, *, ses @transactional_session -def fill_account_counter_history_table(*, session: "Session"): +def fill_account_counter_history_table(*, session: "Session") -> None: """ Make a snapshot of current counters diff --git a/lib/rucio/core/account_limit.py b/lib/rucio/core/account_limit.py index 22343cc661..3e0efcc4d7 100644 --- a/lib/rucio/core/account_limit.py +++ b/lib/rucio/core/account_limit.py @@ -20,7 +20,7 @@ from sqlalchemy.sql import func, select, literal from sqlalchemy.sql.expression import and_, or_ -from rucio.common.types import InternalAccount +from rucio.common.types import InternalAccount, RSEAccountUsageDict, RSEGlobalAccountUsageDict, RSELocalAccountUsageDict, RSEResolvedGlobalAccountLimitDict from rucio.core.account import get_all_rse_usages_per_account from rucio.core.rse import get_rse_name from rucio.core.rse_expression_parser import parse_expression @@ -32,7 +32,7 @@ @read_session -def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[dict]: +def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[RSEAccountUsageDict]: """ Returns the account limit and usage for all accounts on a RSE. @@ -92,7 +92,7 @@ def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[dict]: @read_session -def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict: +def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict[str, RSEResolvedGlobalAccountLimitDict]: """ Returns the global account limits for the account. @@ -164,7 +164,7 @@ def get_local_account_limit(account: InternalAccount, rse_id: str, *, session: " @read_session -def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[str]] = None, *, session: "Session") -> dict: +def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[str]] = None, *, session: "Session") -> dict[uuid.UUID, int]: """ Returns the account limits for the account on the list of rses. @@ -198,7 +198,7 @@ def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[st @transactional_session -def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, *, session: "Session"): +def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, *, session: "Session") -> None: """ Returns the limits for the account on the rse. @@ -216,7 +216,7 @@ def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, @transactional_session -def set_global_account_limit(account: InternalAccount, rse_expression: str, bytes_: int, *, session: "Session"): +def set_global_account_limit(account: InternalAccount, rse_expression: str, bytes_: int, *, session: "Session") -> None: """ Sets the global limit for the account on a RSE expression. @@ -270,7 +270,7 @@ def delete_global_account_limit(account: InternalAccount, rse_expression: str, * @transactional_session -def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[dict]: +def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[RSELocalAccountUsageDict]: """ Read the account usage and connect it with (if available) the account limits of the account. @@ -313,7 +313,7 @@ def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID @transactional_session -def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[dict]: +def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[RSEGlobalAccountUsageDict]: """ Read the account usage and connect it with the global account limits of the account. diff --git a/lib/rucio/core/authentication.py b/lib/rucio/core/authentication.py index 31eb1e5b5b..9a626fc693 100644 --- a/lib/rucio/core/authentication.py +++ b/lib/rucio/core/authentication.py @@ -20,7 +20,7 @@ import sys import traceback from base64 import b64decode -from typing import TYPE_CHECKING, Optional, Union +from typing import Any, TYPE_CHECKING, Optional, Union import paramiko from dogpile.cache import make_region @@ -30,10 +30,10 @@ from rucio.common.cache import make_region_memcached from rucio.common.config import config_get_bool from rucio.common.exception import CannotAuthenticate, RucioException -from rucio.common.types import InternalAccount, TokenDict +from rucio.common.types import InternalAccount, TokenDict, TokenValidationDict from rucio.common.utils import chunks, generate_uuid, date_to_str from rucio.core.account import account_exists -from rucio.core.oidc import validate_jwt +from rucio.core.oidc import validate_jwt, token_dictionary from rucio.db.sqla import filter_thread_work from rucio.db.sqla import models from rucio.db.sqla.constants import IdentityType @@ -41,7 +41,6 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from typing import Any def strip_x509_proxy_attributes(dn: str) -> str: @@ -455,7 +454,7 @@ def delete_expired_tokens(total_workers: int, worker_number: int, limit: int = 1 @read_session -def query_token(token: str, *, session: "Session") -> Optional[dict]: +def query_token(token: str, *, session: "Session") -> Optional[TokenValidationDict]: """ Validate an authentication token using the database. This method will only be called if no entry could be found in the according cache. @@ -488,7 +487,7 @@ def query_token(token: str, *, session: "Session") -> Optional[dict]: @transactional_session -def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": +def validate_auth_token(token: str, *, session: "Session") -> TokenValidationDict: """ Validate an authentication token. @@ -510,7 +509,7 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": cache_key = token.replace(' ', '') # Check if token ca be found in cache region - value: "Union[NO_VALUE, dict[str, Any]]" = TOKENREGION.get(cache_key) + value: Union[NO_VALUE, dict[str, Any]] = TOKENREGION.get(cache_key) if value is NO_VALUE: # no cached entry found value = query_token(token, session=session) if not value: @@ -529,10 +528,6 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": return value -def token_dictionary(token: models.Token) -> TokenDict: - return {'token': token.token, 'expires_at': token.expired_at} - - @transactional_session def __delete_expired_tokens_account(account: InternalAccount, *, session: "Session") -> None: """" diff --git a/lib/rucio/core/oidc.py b/lib/rucio/core/oidc.py index 38b7cf6d94..b68ca6fba2 100644 --- a/lib/rucio/core/oidc.py +++ b/lib/rucio/core/oidc.py @@ -20,7 +20,7 @@ import subprocess import traceback from datetime import datetime, timedelta -from typing import Any, Final, TYPE_CHECKING, Optional +from typing import Any, Final, TYPE_CHECKING, Optional, Union from urllib.parse import urljoin, urlparse, parse_qs import requests @@ -43,7 +43,7 @@ from rucio.common.exception import (CannotAuthenticate, CannotAuthorize, RucioException) from rucio.common.stopwatch import Stopwatch -from rucio.common.types import InternalAccount, TokenDict +from rucio.common.types import InternalAccount, TokenDict, TokenValidationDict, TokenOIDCAutoDict, TokenOIDCNoAutoDict, TokenOIDCPollingDict from rucio.common.utils import all_oidc_req_claims_present, build_url, val_to_space_sep_str from rucio.core.account import account_exists from rucio.core.identity import exist_identity_account, get_default_account @@ -156,7 +156,7 @@ def request_token(audience: str, scope: str, use_cache: bool = True) -> Optional return token -def __get_rucio_oidc_clients(keytimeout: int = 43200) -> tuple[dict, dict]: +def __get_rucio_oidc_clients(keytimeout: int = 43200) -> tuple[dict[str, Client], dict[str, Client]]: """ Creates a Rucio OIDC Client instances per Identity Provider (IdP) according to etc/idpsecrets.json configuration file. @@ -469,7 +469,7 @@ def get_auth_oidc(account: str, *, session: "Session", **kwargs) -> str: @transactional_session -def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session") -> Optional[dict]: +def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session") -> Optional[Union[TokenOIDCNoAutoDict, TokenOIDCAutoDict, TokenOIDCPollingDict]]: """ After Rucio User got redirected to Rucio /auth/oidc_token (or /auth/oidc_code) REST endpoints with authz code and session state encoded within the URL. @@ -689,7 +689,7 @@ def __get_admin_token_oidc(account: InternalAccount, req_scope: str, req_audienc @read_session -def __get_admin_account_for_issuer(*, session: "Session") -> dict: +def __get_admin_account_for_issuer(*, session: "Session") -> dict[str, tuple[InternalAccount, str]]: """ Gets admin account for the IdP issuer :returns: dictionary { 'issuer_1': (account, identity), ... } """ @@ -1286,7 +1286,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): @read_session -def __get_rucio_jwt_dict(jwt: str, account: Optional[InternalAccount] = None, *, session: "Session") -> Optional[dict]: +def __get_rucio_jwt_dict(jwt: str, account: Optional[InternalAccount] = None, *, session: "Session") -> Optional[TokenValidationDict]: """ Get a Rucio token dictionary from token claims. Check token expiration and find default Rucio @@ -1330,7 +1330,7 @@ def __get_rucio_jwt_dict(jwt: str, account: Optional[InternalAccount] = None, *, @transactional_session -def __save_validated_token(token: str, valid_dict: dict, extra_dict: Optional[dict] = None, *, session: "Session") -> TokenDict: +def __save_validated_token(token: str, valid_dict: TokenValidationDict, extra_dict: Optional[dict] = None, *, session: "Session") -> TokenDict: """ Save JWT token to the Rucio DB.