From 5ce06427413e422658bef45a9c6949a7b5270d2c Mon Sep 17 00:00:00 2001 From: rdimaio Date: Tue, 27 Feb 2024 18:53:00 +0100 Subject: [PATCH] Auth: Refactor to use custom dict types and fix type hints --- lib/rucio/common/types.py | 91 +++++++++++++++++++++++++++++++ lib/rucio/common/utils.py | 4 +- lib/rucio/core/account.py | 48 ++++++++-------- lib/rucio/core/account_counter.py | 14 ++--- lib/rucio/core/account_limit.py | 16 +++--- lib/rucio/core/authentication.py | 21 +++---- lib/rucio/core/oidc.py | 37 ++++++------- 7 files changed, 158 insertions(+), 73 deletions(-) diff --git a/lib/rucio/common/types.py b/lib/rucio/common/types.py index 849d89f6c0..55baf86ed3 100644 --- a/lib/rucio/common/types.py +++ b/lib/rucio/common/types.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from uuid import UUID from datetime import datetime from typing import Any, Callable, Optional, TypedDict, Union +from rucio.db.sqla.constants import AccountType, IdentityType + class InternalType(object): ''' @@ -162,6 +165,94 @@ class RSESettingsDict(TypedDict): protocols: list[RSEProtocolDict] +class RSEAccountCounterDict(TypedDict): + account: InternalAccount + rse_id: UUID + + +class RSEAccountUsageDict(TypedDict): + rse_id: UUID + rse: str + account: InternalAccount + used_files: int + used_bytes: int + quota_bytes: int + + +class RSEGlobalAccountUsageDict(TypedDict): + rse_expression: str + bytes: int + files: int + bytes_limit: int + bytes_remaining: int + + +class RSELocalAccountUsageDict(TypedDict): + rse_id: UUID + rse: str + bytes: int + files: int + bytes_limit: int + bytes_remaining: int + + +class RSEResolvedGlobalAccountLimitDict(TypedDict): + resolved_rses: str + resolved_rse_ids: list[UUID] + limit: float + + class TokenDict(TypedDict): token: str expires_at: datetime + + +class TokenValidationDict(TypedDict): + account: Optional[InternalAccount] + identity: Optional[str] + lifetime: Optional[datetime] + audience: Optional[str] + authz_scope: Optional[str] + + +class AccountDict(TypedDict): + account: InternalAccount + type: AccountType + email: str + + +class AccountAttributesDict(TypedDict): + key: str + value: Union[bool, str] + + +class IdentityDict(TypedDict): + type: IdentityType + identity: str + email: str + + +class UsageDict(TypedDict): + bytes: int + files: int + updated_at: Optional[datetime] + + +class TokenOIDCAutoDict(TypedDict, total=False): + webhome: str + token: TokenDict + + +class TokenOIDCNoAutoDict(TypedDict): + fetchcode: str + + +class TokenOIDCPollingDict(TypedDict): + polling: bool + + +class AccountUsageModelDict(TypedDict): + account: InternalAccount + rse_id: UUID + files: int + bytes: int diff --git a/lib/rucio/common/utils.py b/lib/rucio/common/utils.py index 60ea7bdb7c..7b3a2e803d 100644 --- a/lib/rucio/common/utils.py +++ b/lib/rucio/common/utils.py @@ -40,7 +40,7 @@ from functools import partial, wraps from io import StringIO from itertools import zip_longest -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union from urllib.parse import urlparse, urlencode, quote, parse_qsl, urlunparse from uuid import uuid4 as uuid from xml.etree import ElementTree @@ -1200,7 +1200,7 @@ def detect_client_location(): 'longitude': longitude} -def ssh_sign(private_key: str, message: str) -> str: +def ssh_sign(private_key: str, message: Union[str, bytes]) -> str: """ Sign a string message using the private key. diff --git a/lib/rucio/core/account.py b/lib/rucio/core/account.py index 1015b5d016..528f88a18b 100644 --- a/lib/rucio/core/account.py +++ b/lib/rucio/core/account.py @@ -17,7 +17,7 @@ from enum import Enum from re import match from traceback import format_exc -from typing import TYPE_CHECKING, Any, Iterator, Optional +from typing import TYPE_CHECKING, Any, cast, Iterator, Optional from sqlalchemy import select, and_ from sqlalchemy.exc import IntegrityError @@ -27,7 +27,7 @@ import rucio.core.rse from rucio.common import exception from rucio.common.config import config_get_bool -from rucio.common.types import InternalAccount +from rucio.common.types import InternalAccount, AccountAttributesDict, AccountDict, AccountUsageModelDict, IdentityDict, UsageDict from rucio.core.vo import vo_exists from rucio.db.sqla import models from rucio.db.sqla.constants import AccountStatus, AccountType @@ -138,22 +138,22 @@ def update_account(account: InternalAccount, key: str, value: Any, *, session: " models.Account.account == account ) try: - account_result = session.execute(query).scalar_one() - if key == 'status': - if isinstance(value, str): - value = AccountStatus[value] - if value == AccountStatus.SUSPENDED: - account_result.update({'status': value, 'suspended_at': datetime.utcnow()}) - elif value == AccountStatus.ACTIVE: - account_result.update({'status': value, 'suspended_at': None}) - else: - account_result.update({key: value}) + query_result = session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) + if key == 'status': + if isinstance(value, str): + value = AccountStatus[value] + if value == AccountStatus.SUSPENDED: + query_result.update({'status': value, 'suspended_at': datetime.utcnow()}) + elif value == AccountStatus.ACTIVE: + query_result.update({'status': value, 'suspended_at': None}) + else: + query_result.update({key: value}) @stream_session -def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[dict]: +def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[AccountDict]: """ Returns a list of all account names. :param filter_: Dictionary of attributes by which the input data should be filtered @@ -214,7 +214,7 @@ def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session @read_session -def list_identities(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]: +def list_identities(account: InternalAccount, *, session: "Session") -> list[IdentityDict]: """ List all identities on an account. @@ -245,11 +245,11 @@ def list_identities(account: InternalAccount, *, session: "Session") -> list[dic ).where( models.IdentityAccountAssociation.account == account ) - return [row._asdict() for row in session.execute(query)] + return [cast(IdentityDict, row._asdict()) for row in session.execute(query)] @read_session -def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]: +def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[AccountAttributesDict]: """ Get all attributes defined for an account. @@ -275,7 +275,7 @@ def list_account_attributes(account: InternalAccount, *, session: "Session") -> ).where( models.AccountAttrAssociation.account == account ) - return [row._asdict() for row in session.execute(query)] + return [cast(AccountAttributesDict, row._asdict()) for row in session.execute(query)] @read_session @@ -356,7 +356,7 @@ def del_account_attribute(account: InternalAccount, key: str, *, session: "Sessi @read_session -def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> dict: +def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> UsageDict: """ Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist. @@ -374,13 +374,13 @@ def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> d models.AccountUsage.account == account ) try: - return session.execute(query).one()._asdict() + return cast(UsageDict, session.execute(query).one()._asdict()) except exc.NoResultFound: - return {'bytes': 0, 'files': 0, 'updated_at': None} + return UsageDict({'bytes': 0, 'files': 0, 'updated_at': None}) @read_session -def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[dict]: +def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[AccountUsageModelDict]: """ Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist. @@ -395,13 +395,13 @@ def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Sessio models.AccountUsage.account == account ) try: - return [result.to_dict() for result in session.execute(query).scalars()] + return [cast(AccountUsageModelDict, result.to_dict()) for result in session.execute(query).scalars()] except exc.NoResultFound: return [] @read_session -def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[dict]: +def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[UsageDict]: """ Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist. @@ -421,6 +421,6 @@ def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Sessio models.AccountUsageHistory.updated_at ) try: - return [row._asdict() for row in session.execute(query)] + return [cast(UsageDict, row._asdict()) for row in session.execute(query)] except exc.NoResultFound: raise exception.CounterNotFound('No usage can be found for account %s on RSE %s' % (account, rucio.core.rse.get_rse_name(rse_id=rse_id, session=session))) diff --git a/lib/rucio/core/account_counter.py b/lib/rucio/core/account_counter.py index c986cf02af..fb21653b20 100644 --- a/lib/rucio/core/account_counter.py +++ b/lib/rucio/core/account_counter.py @@ -14,12 +14,12 @@ # limitations under the License. import datetime -from typing import TYPE_CHECKING, Any +from typing import cast, TYPE_CHECKING from sqlalchemy import literal, insert, select from sqlalchemy.orm.exc import NoResultFound -from rucio.common.types import InternalAccount +from rucio.common.types import InternalAccount, RSEAccountCounterDict from rucio.db.sqla import models, filter_thread_work from rucio.db.sqla.session import read_session, transactional_session @@ -30,7 +30,7 @@ @transactional_session -def add_counter(rse_id: str, account: InternalAccount, *, session: "Session"): +def add_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> None: """ Creates the specified counter for a rse_id and account. @@ -43,7 +43,7 @@ def add_counter(rse_id: str, account: InternalAccount, *, session: "Session"): @transactional_session -def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session"): +def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None: """ Increments the specified counter by the specified amount. @@ -84,7 +84,7 @@ def del_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> @read_session -def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[dict[str, Any]]: +def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[RSEAccountCounterDict]: """ Get updated rse_counters. @@ -102,7 +102,7 @@ def get_updated_account_counters(total_workers: int, worker_number: int, *, sess query = filter_thread_work(session=session, query=query, total_threads=total_workers, thread_id=worker_number, hash_variable='CONCAT(account, rse_id)') - return [row._asdict() for row in session.execute(query).all()] + return [cast(RSEAccountCounterDict, row._asdict()) for row in session.execute(query).all()] @transactional_session @@ -148,7 +148,7 @@ def update_account_counter_history(account: InternalAccount, rse_id: str, *, ses @transactional_session -def fill_account_counter_history_table(*, session: "Session"): +def fill_account_counter_history_table(*, session: "Session") -> None: """ Make a snapshot of current counters diff --git a/lib/rucio/core/account_limit.py b/lib/rucio/core/account_limit.py index 22343cc661..3e0efcc4d7 100644 --- a/lib/rucio/core/account_limit.py +++ b/lib/rucio/core/account_limit.py @@ -20,7 +20,7 @@ from sqlalchemy.sql import func, select, literal from sqlalchemy.sql.expression import and_, or_ -from rucio.common.types import InternalAccount +from rucio.common.types import InternalAccount, RSEAccountUsageDict, RSEGlobalAccountUsageDict, RSELocalAccountUsageDict, RSEResolvedGlobalAccountLimitDict from rucio.core.account import get_all_rse_usages_per_account from rucio.core.rse import get_rse_name from rucio.core.rse_expression_parser import parse_expression @@ -32,7 +32,7 @@ @read_session -def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[dict]: +def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[RSEAccountUsageDict]: """ Returns the account limit and usage for all accounts on a RSE. @@ -92,7 +92,7 @@ def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[dict]: @read_session -def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict: +def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict[str, RSEResolvedGlobalAccountLimitDict]: """ Returns the global account limits for the account. @@ -164,7 +164,7 @@ def get_local_account_limit(account: InternalAccount, rse_id: str, *, session: " @read_session -def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[str]] = None, *, session: "Session") -> dict: +def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[str]] = None, *, session: "Session") -> dict[uuid.UUID, int]: """ Returns the account limits for the account on the list of rses. @@ -198,7 +198,7 @@ def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[st @transactional_session -def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, *, session: "Session"): +def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, *, session: "Session") -> None: """ Returns the limits for the account on the rse. @@ -216,7 +216,7 @@ def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, @transactional_session -def set_global_account_limit(account: InternalAccount, rse_expression: str, bytes_: int, *, session: "Session"): +def set_global_account_limit(account: InternalAccount, rse_expression: str, bytes_: int, *, session: "Session") -> None: """ Sets the global limit for the account on a RSE expression. @@ -270,7 +270,7 @@ def delete_global_account_limit(account: InternalAccount, rse_expression: str, * @transactional_session -def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[dict]: +def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[RSELocalAccountUsageDict]: """ Read the account usage and connect it with (if available) the account limits of the account. @@ -313,7 +313,7 @@ def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID @transactional_session -def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[dict]: +def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[RSEGlobalAccountUsageDict]: """ Read the account usage and connect it with the global account limits of the account. diff --git a/lib/rucio/core/authentication.py b/lib/rucio/core/authentication.py index 31eb1e5b5b..2da4f04f87 100644 --- a/lib/rucio/core/authentication.py +++ b/lib/rucio/core/authentication.py @@ -20,7 +20,7 @@ import sys import traceback from base64 import b64decode -from typing import TYPE_CHECKING, Optional, Union +from typing import Any, cast, TYPE_CHECKING, Optional, Union import paramiko from dogpile.cache import make_region @@ -30,10 +30,10 @@ from rucio.common.cache import make_region_memcached from rucio.common.config import config_get_bool from rucio.common.exception import CannotAuthenticate, RucioException -from rucio.common.types import InternalAccount, TokenDict +from rucio.common.types import InternalAccount, TokenDict, TokenValidationDict from rucio.common.utils import chunks, generate_uuid, date_to_str from rucio.core.account import account_exists -from rucio.core.oidc import validate_jwt +from rucio.core.oidc import validate_jwt, token_dictionary from rucio.db.sqla import filter_thread_work from rucio.db.sqla import models from rucio.db.sqla.constants import IdentityType @@ -41,7 +41,6 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from typing import Any def strip_x509_proxy_attributes(dn: str) -> str: @@ -455,7 +454,7 @@ def delete_expired_tokens(total_workers: int, worker_number: int, limit: int = 1 @read_session -def query_token(token: str, *, session: "Session") -> Optional[dict]: +def query_token(token: str, *, session: "Session") -> Optional[TokenValidationDict]: """ Validate an authentication token using the database. This method will only be called if no entry could be found in the according cache. @@ -483,12 +482,12 @@ def query_token(token: str, *, session: "Session") -> Optional[dict]: ) result = session.execute(query).first() if result: - return result._asdict() + return cast(TokenValidationDict, result._asdict()) return None @transactional_session -def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": +def validate_auth_token(token: str, *, session: "Session") -> TokenValidationDict: """ Validate an authentication token. @@ -510,7 +509,7 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": cache_key = token.replace(' ', '') # Check if token ca be found in cache region - value: "Union[NO_VALUE, dict[str, Any]]" = TOKENREGION.get(cache_key) + value: Union[NO_VALUE, dict[str, Any]] = TOKENREGION.get(cache_key) if value is NO_VALUE: # no cached entry found value = query_token(token, session=session) if not value: @@ -526,11 +525,7 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": if lifetime < datetime.datetime.utcnow(): # check if expired TOKENREGION.delete(cache_key) raise CannotAuthenticate(f"Token found but expired since {date_to_str(lifetime)}.") - return value - - -def token_dictionary(token: models.Token) -> TokenDict: - return {'token': token.token, 'expires_at': token.expired_at} + return cast(TokenValidationDict, value) @transactional_session diff --git a/lib/rucio/core/oidc.py b/lib/rucio/core/oidc.py index 38b7cf6d94..c6618325b9 100644 --- a/lib/rucio/core/oidc.py +++ b/lib/rucio/core/oidc.py @@ -20,7 +20,7 @@ import subprocess import traceback from datetime import datetime, timedelta -from typing import Any, Final, TYPE_CHECKING, Optional +from typing import Any, Final, TYPE_CHECKING, Optional, Union from urllib.parse import urljoin, urlparse, parse_qs import requests @@ -43,7 +43,7 @@ from rucio.common.exception import (CannotAuthenticate, CannotAuthorize, RucioException) from rucio.common.stopwatch import Stopwatch -from rucio.common.types import InternalAccount, TokenDict +from rucio.common.types import InternalAccount, TokenDict, TokenValidationDict, TokenOIDCAutoDict, TokenOIDCNoAutoDict, TokenOIDCPollingDict from rucio.common.utils import all_oidc_req_claims_present, build_url, val_to_space_sep_str from rucio.core.account import account_exists from rucio.core.identity import exist_identity_account, get_default_account @@ -156,7 +156,7 @@ def request_token(audience: str, scope: str, use_cache: bool = True) -> Optional return token -def __get_rucio_oidc_clients(keytimeout: int = 43200) -> tuple[dict, dict]: +def __get_rucio_oidc_clients(keytimeout: int = 43200) -> tuple[dict[str, Client], dict[str, Client]]: """ Creates a Rucio OIDC Client instances per Identity Provider (IdP) according to etc/idpsecrets.json configuration file. @@ -469,7 +469,7 @@ def get_auth_oidc(account: str, *, session: "Session", **kwargs) -> str: @transactional_session -def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session") -> Optional[dict]: +def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session") -> Optional[Union[TokenOIDCNoAutoDict, TokenOIDCAutoDict, TokenOIDCPollingDict]]: """ After Rucio User got redirected to Rucio /auth/oidc_token (or /auth/oidc_code) REST endpoints with authz code and session state encoded within the URL. @@ -530,7 +530,7 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" try: jwt_row_dict['account'] = get_default_account(jwt_row_dict['identity'], IdentityType.OIDC, True, session=session) except Exception: - return {'webhome': None, 'token': None} + return TokenOIDCAutoDict({'webhome': None, 'token': None}) # check if given account has the identity registered if not exist_identity_account(jwt_row_dict['identity'], IdentityType.OIDC, jwt_row_dict['account'], session=session): @@ -612,14 +612,14 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" session.commit() METRICS.timer('IdP_authorization').observe(stopwatch.elapsed) if '_polling' in oauth_req_params.access_msg: - return {'polling': True} + return TokenOIDCPollingDict({'polling': True}) elif 'http' in oauth_req_params.access_msg: - return {'webhome': oauth_req_params.access_msg, 'token': new_token} + return TokenOIDCAutoDict({'webhome': oauth_req_params.access_msg, 'token': new_token}) else: - return {'fetchcode': fetchcode} + return TokenOIDCNoAutoDict({'fetchcode': fetchcode}) else: METRICS.timer('IdP_authorization').observe(stopwatch.elapsed) - return {'token': new_token} + return TokenOIDCAutoDict({'token': new_token}) except Exception: # TO-DO catch different exceptions - InvalidGrant etc. ... @@ -689,7 +689,7 @@ def __get_admin_token_oidc(account: InternalAccount, req_scope: str, req_audienc @read_session -def __get_admin_account_for_issuer(*, session: "Session") -> dict: +def __get_admin_account_for_issuer(*, session: "Session") -> dict[str, tuple[InternalAccount, str]]: """ Gets admin account for the IdP issuer :returns: dictionary { 'issuer_1': (account, identity), ... } """ @@ -1286,7 +1286,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): @read_session -def __get_rucio_jwt_dict(jwt: str, account: Optional[InternalAccount] = None, *, session: "Session") -> Optional[dict]: +def __get_rucio_jwt_dict(jwt: str, account: Optional[InternalAccount] = None, *, session: "Session") -> Optional[TokenValidationDict]: """ Get a Rucio token dictionary from token claims. Check token expiration and find default Rucio @@ -1318,19 +1318,18 @@ def __get_rucio_jwt_dict(jwt: str, account: Optional[InternalAccount] = None, *, if not exist_identity_account(identity_string, IdentityType.OIDC, account, session=session): logging.debug("No OIDC identity exists for account: %s", str(account)) return None - value = {'account': account, - 'identity': identity_string, - 'lifetime': expiry_date, - 'audience': audience, - 'authz_scope': scope} - return value + return TokenValidationDict({'account': account, + 'identity': identity_string, + 'lifetime': expiry_date, + 'audience': audience, + 'authz_scope': scope}) except Exception: logging.debug(traceback.format_exc()) return None @transactional_session -def __save_validated_token(token: str, valid_dict: dict, extra_dict: Optional[dict] = None, *, session: "Session") -> TokenDict: +def __save_validated_token(token: str, valid_dict: TokenValidationDict, extra_dict: Optional[dict] = None, *, session: "Session") -> TokenDict: """ Save JWT token to the Rucio DB. @@ -1405,7 +1404,7 @@ def validate_jwt(json_web_token: str, *, session: "Session") -> dict[str, Any]: clprocess = subprocess.Popen(['curl', '-s', '-L', '-u', '%s:%s' % (oidc_client.client_id, oidc_client.client_secret), '-d', 'token=%s' % (json_web_token), - oidc_client.introspection_endpoint], + oidc_client.introspection_endpoint], # type: ignore shell=False, stdout=subprocess.PIPE) inspect_claims = json.loads(clprocess.communicate()[0]) try: