From a1b8b4bb39f7d8e966c44f366f5014fe6a44f375 Mon Sep 17 00:00:00 2001 From: rdimaio Date: Thu, 7 Mar 2024 18:46:05 +0100 Subject: [PATCH] Typing: Add type hints; #6454 --- lib/rucio/common/utils.py | 4 +- lib/rucio/core/account.py | 58 +++++++++++++------------- lib/rucio/core/account_counter.py | 20 ++++----- lib/rucio/core/account_limit.py | 31 ++++++++------ lib/rucio/core/authentication.py | 30 ++++++-------- lib/rucio/core/oidc.py | 67 +++++++++++++++---------------- 6 files changed, 106 insertions(+), 104 deletions(-) diff --git a/lib/rucio/common/utils.py b/lib/rucio/common/utils.py index 7b3b6c69df..574e2d8cc1 100644 --- a/lib/rucio/common/utils.py +++ b/lib/rucio/common/utils.py @@ -40,7 +40,7 @@ from functools import partial, wraps from io import StringIO from itertools import zip_longest -from typing import TYPE_CHECKING, Callable, Optional +from typing import Any, TYPE_CHECKING, Callable, Optional from urllib.parse import urlparse, urlencode, quote, parse_qsl, urlunparse from uuid import uuid4 as uuid from xml.etree import ElementTree @@ -568,7 +568,7 @@ def str_to_date(string): return datetime.datetime.strptime(string, DATE_FORMAT) if string else None -def val_to_space_sep_str(vallist): +def val_to_space_sep_str(vallist: Any) -> str: """ Converts a list of values into a string of space separated values :param vallist: the list of values to to convert into string diff --git a/lib/rucio/core/account.py b/lib/rucio/core/account.py index b65d702e1b..528f88a18b 100644 --- a/lib/rucio/core/account.py +++ b/lib/rucio/core/account.py @@ -17,7 +17,7 @@ from enum import Enum from re import match from traceback import format_exc -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast, Iterator, Optional from sqlalchemy import select, and_ from sqlalchemy.exc import IntegrityError @@ -27,6 +27,7 @@ import rucio.core.rse from rucio.common import exception from rucio.common.config import config_get_bool +from rucio.common.types import InternalAccount, AccountAttributesDict, AccountDict, AccountUsageModelDict, IdentityDict, UsageDict from rucio.core.vo import vo_exists from rucio.db.sqla import models from rucio.db.sqla.constants import AccountStatus, AccountType @@ -37,7 +38,7 @@ @transactional_session -def add_account(account, type_, email, *, session: "Session"): +def add_account(account: InternalAccount, type_: AccountType, email: str, *, session: "Session") -> None: """ Add an account with the given account name and type. :param account: the name of the new account. @@ -63,7 +64,7 @@ def add_account(account, type_, email, *, session: "Session"): @read_session -def account_exists(account, *, session: "Session"): +def account_exists(account: InternalAccount, *, session: "Session") -> bool: """ Checks to see if account exists and is active. :param account: Name of the account. @@ -81,7 +82,7 @@ def account_exists(account, *, session: "Session"): @read_session -def get_account(account, *, session: "Session"): +def get_account(account: InternalAccount, *, session: "Session") -> models.Account: """ Returns an account for the given account name. :param account: the name of the account. @@ -102,7 +103,7 @@ def get_account(account, *, session: "Session"): @transactional_session -def del_account(account, *, session: "Session"): +def del_account(account: InternalAccount, *, session: "Session") -> None: """ Disable an account with the given account name. :param account: the account name. @@ -115,15 +116,15 @@ def del_account(account, *, session: "Session"): models.Account.status == AccountStatus.ACTIVE ) try: - account = session.execute(query).scalar_one() + account_result = session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) - account.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()}) + account_result.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()}) @transactional_session -def update_account(account, key, value, *, session: "Session"): +def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None: """ Update a property of an account. :param account: Name of the account. @@ -137,22 +138,22 @@ def update_account(account, key, value, *, session: "Session"): models.Account.account == account ) try: - account = session.execute(query).scalar_one() + query_result = session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) if key == 'status': if isinstance(value, str): value = AccountStatus[value] if value == AccountStatus.SUSPENDED: - account.update({'status': value, 'suspended_at': datetime.utcnow()}) + query_result.update({'status': value, 'suspended_at': datetime.utcnow()}) elif value == AccountStatus.ACTIVE: - account.update({'status': value, 'suspended_at': None}) + query_result.update({'status': value, 'suspended_at': None}) else: - account.update({key: value}) + query_result.update({key: value}) @stream_session -def list_accounts(filter_=None, *, session: "Session"): +def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[AccountDict]: """ Returns a list of all account names. :param filter_: Dictionary of attributes by which the input data should be filtered @@ -213,7 +214,7 @@ def list_accounts(filter_=None, *, session: "Session"): @read_session -def list_identities(account, *, session: "Session"): +def list_identities(account: InternalAccount, *, session: "Session") -> list[IdentityDict]: """ List all identities on an account. @@ -244,11 +245,11 @@ def list_identities(account, *, session: "Session"): ).where( models.IdentityAccountAssociation.account == account ) - return [row._asdict() for row in session.execute(query)] + return [cast(IdentityDict, row._asdict()) for row in session.execute(query)] @read_session -def list_account_attributes(account, *, session: "Session"): +def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[AccountAttributesDict]: """ Get all attributes defined for an account. @@ -274,11 +275,11 @@ def list_account_attributes(account, *, session: "Session"): ).where( models.AccountAttrAssociation.account == account ) - return [row._asdict() for row in session.execute(query)] + return [cast(AccountAttributesDict, row._asdict()) for row in session.execute(query)] @read_session -def has_account_attribute(account, key, *, session: "Session"): +def has_account_attribute(account: InternalAccount, key: str, *, session: "Session") -> bool: """ Indicates whether the named key is present for the account. @@ -298,7 +299,7 @@ def has_account_attribute(account, key, *, session: "Session"): @transactional_session -def add_account_attribute(account, key, value, *, session: "Session"): +def add_account_attribute(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None: """ Add an attribute for the given account name. @@ -334,7 +335,7 @@ def add_account_attribute(account, key, value, *, session: "Session"): @transactional_session -def del_account_attribute(account, key, *, session: "Session"): +def del_account_attribute(account: InternalAccount, key: str, *, session: "Session") -> None: """ Add an attribute for the given account name. @@ -355,7 +356,7 @@ def del_account_attribute(account, key, *, session: "Session"): @read_session -def get_usage(rse_id, account, *, session: "Session"): +def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> UsageDict: """ Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist. @@ -373,13 +374,13 @@ def get_usage(rse_id, account, *, session: "Session"): models.AccountUsage.account == account ) try: - return session.execute(query).one()._asdict() + return cast(UsageDict, session.execute(query).one()._asdict()) except exc.NoResultFound: - return {'bytes': 0, 'files': 0, 'updated_at': None} + return UsageDict({'bytes': 0, 'files': 0, 'updated_at': None}) @read_session -def get_all_rse_usages_per_account(account, *, session: "Session"): +def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[AccountUsageModelDict]: """ Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist. @@ -394,20 +395,20 @@ def get_all_rse_usages_per_account(account, *, session: "Session"): models.AccountUsage.account == account ) try: - return [result.to_dict() for result in session.execute(query).scalars()] + return [cast(AccountUsageModelDict, result.to_dict()) for result in session.execute(query).scalars()] except exc.NoResultFound: return [] @read_session -def get_usage_history(rse_id, account, *, session: "Session"): +def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[UsageDict]: """ Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist. :param rse_id: The id of the RSE. :param account: The account name. :param session: The database session in use. - :returns: A dictionary {'bytes', 'files', 'updated_at'} + :returns: A list of dictionaries {'bytes', 'files', 'updated_at'} """ query = select( models.AccountUsageHistory.bytes, @@ -420,7 +421,6 @@ def get_usage_history(rse_id, account, *, session: "Session"): models.AccountUsageHistory.updated_at ) try: - return [row._asdict() for row in session.execute(query)] + return [cast(UsageDict, row._asdict()) for row in session.execute(query)] except exc.NoResultFound: raise exception.CounterNotFound('No usage can be found for account %s on RSE %s' % (account, rucio.core.rse.get_rse_name(rse_id=rse_id, session=session))) - return [] diff --git a/lib/rucio/core/account_counter.py b/lib/rucio/core/account_counter.py index 34957a6549..fb21653b20 100644 --- a/lib/rucio/core/account_counter.py +++ b/lib/rucio/core/account_counter.py @@ -12,12 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import datetime -from typing import TYPE_CHECKING +from typing import cast, TYPE_CHECKING from sqlalchemy import literal, insert, select from sqlalchemy.orm.exc import NoResultFound +from rucio.common.types import InternalAccount, RSEAccountCounterDict from rucio.db.sqla import models, filter_thread_work from rucio.db.sqla.session import read_session, transactional_session @@ -28,7 +30,7 @@ @transactional_session -def add_counter(rse_id, account, *, session: "Session"): +def add_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> None: """ Creates the specified counter for a rse_id and account. @@ -41,7 +43,7 @@ def add_counter(rse_id, account, *, session: "Session"): @transactional_session -def increase(rse_id, account, files, bytes_, *, session: "Session"): +def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None: """ Increments the specified counter by the specified amount. @@ -55,7 +57,7 @@ def increase(rse_id, account, files, bytes_, *, session: "Session"): @transactional_session -def decrease(rse_id, account, files, bytes_, *, session: "Session"): +def decrease(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None: """ Decreases the specified counter by the specified amount. @@ -69,7 +71,7 @@ def decrease(rse_id, account, files, bytes_, *, session: "Session"): @transactional_session -def del_counter(rse_id, account, *, session: "Session"): +def del_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> None: """ Resets the specified counter and initializes it by the specified amounts. @@ -82,7 +84,7 @@ def del_counter(rse_id, account, *, session: "Session"): @read_session -def get_updated_account_counters(total_workers, worker_number, *, session: "Session"): +def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[RSEAccountCounterDict]: """ Get updated rse_counters. @@ -104,7 +106,7 @@ def get_updated_account_counters(total_workers, worker_number, *, session: "Sess @transactional_session -def update_account_counter(account, rse_id, *, session: "Session"): +def update_account_counter(account: InternalAccount, rse_id: str, *, session: "Session") -> None: """ Read the updated_account_counters and update the account_counter. @@ -130,7 +132,7 @@ def update_account_counter(account, rse_id, *, session: "Session"): @transactional_session -def update_account_counter_history(account, rse_id, *, session: "Session"): +def update_account_counter_history(account: InternalAccount, rse_id: str, *, session: "Session") -> None: """ Read the AccountUsage and update the AccountUsageHistory. @@ -146,7 +148,7 @@ def update_account_counter_history(account, rse_id, *, session: "Session"): @transactional_session -def fill_account_counter_history_table(*, session: "Session"): +def fill_account_counter_history_table(*, session: "Session") -> None: """ Make a snapshot of current counters diff --git a/lib/rucio/core/account_limit.py b/lib/rucio/core/account_limit.py index fd7b84722d..2bb20208dc 100644 --- a/lib/rucio/core/account_limit.py +++ b/lib/rucio/core/account_limit.py @@ -13,12 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union +from uuid import UUID from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy.sql import func, select, literal +from sqlalchemy.sql import func, literal, select from sqlalchemy.sql.expression import and_, or_ +from rucio.common.types import (InternalAccount, RSEAccountUsageDict, + RSEGlobalAccountUsageDict, + RSELocalAccountUsageDict, + RSEResolvedGlobalAccountLimitDict) from rucio.core.account import get_all_rse_usages_per_account from rucio.core.rse import get_rse_name from rucio.core.rse_expression_parser import parse_expression @@ -30,7 +35,7 @@ @read_session -def get_rse_account_usage(rse_id, *, session: "Session"): +def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[RSEAccountUsageDict]: """ Returns the account limit and usage for all accounts on a RSE. @@ -90,7 +95,7 @@ def get_rse_account_usage(rse_id, *, session: "Session"): @read_session -def get_global_account_limits(account=None, *, session: "Session"): +def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict[str, RSEResolvedGlobalAccountLimitDict]: """ Returns the global account limits for the account. @@ -121,7 +126,7 @@ def get_global_account_limits(account=None, *, session: "Session"): @read_session -def get_global_account_limit(account, rse_expression, *, session: "Session"): +def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> Union[int, float, None]: """ Returns the global account limit for the account on the rse expression. @@ -141,7 +146,7 @@ def get_global_account_limit(account, rse_expression, *, session: "Session"): @read_session -def get_local_account_limit(account, rse_id, *, session: "Session"): +def get_local_account_limit(account: InternalAccount, rse_id: str, *, session: "Session") -> Union[int, float, None]: """ Returns the account limit for the account on the rse. @@ -162,7 +167,7 @@ def get_local_account_limit(account, rse_id, *, session: "Session"): @read_session -def get_local_account_limits(account, rse_ids=None, *, session: "Session"): +def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[str]] = None, *, session: "Session") -> dict[str, int]: """ Returns the account limits for the account on the list of rses. @@ -196,7 +201,7 @@ def get_local_account_limits(account, rse_ids=None, *, session: "Session"): @transactional_session -def set_local_account_limit(account, rse_id, bytes_, *, session: "Session"): +def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, *, session: "Session") -> None: """ Returns the limits for the account on the rse. @@ -214,7 +219,7 @@ def set_local_account_limit(account, rse_id, bytes_, *, session: "Session"): @transactional_session -def set_global_account_limit(account, rse_expression, bytes_, *, session: "Session"): +def set_global_account_limit(account: InternalAccount, rse_expression: str, bytes_: int, *, session: "Session") -> None: """ Sets the global limit for the account on a RSE expression. @@ -232,7 +237,7 @@ def set_global_account_limit(account, rse_expression, bytes_, *, session: "Sessi @transactional_session -def delete_local_account_limit(account, rse_id, *, session: "Session"): +def delete_local_account_limit(account: InternalAccount, rse_id: str, *, session: "Session") -> bool: """ Deletes a local account limit. @@ -250,7 +255,7 @@ def delete_local_account_limit(account, rse_id, *, session: "Session"): @transactional_session -def delete_global_account_limit(account, rse_expression, *, session: "Session"): +def delete_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> bool: """ Deletes a global account limit. @@ -268,7 +273,7 @@ def delete_global_account_limit(account, rse_expression, *, session: "Session"): @transactional_session -def get_local_account_usage(account, rse_id=None, *, session: "Session"): +def get_local_account_usage(account: InternalAccount, rse_id: Optional[UUID] = None, *, session: "Session") -> list[RSELocalAccountUsageDict]: """ Read the account usage and connect it with (if available) the account limits of the account. @@ -311,7 +316,7 @@ def get_local_account_usage(account, rse_id=None, *, session: "Session"): @transactional_session -def get_global_account_usage(account, rse_expression=None, *, session: "Session"): +def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[RSEGlobalAccountUsageDict]: """ Read the account usage and connect it with the global account limits of the account. diff --git a/lib/rucio/core/authentication.py b/lib/rucio/core/authentication.py index 8ea891bf39..2da4f04f87 100644 --- a/lib/rucio/core/authentication.py +++ b/lib/rucio/core/authentication.py @@ -90,7 +90,7 @@ def generate_key(token, *, session: "Session"): @transactional_session -def get_auth_token_user_pass(account, username, password, appid, ip=None, *, session: "Session"): +def get_auth_token_user_pass(account: InternalAccount, username: str, password: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via username and password. @@ -150,7 +150,7 @@ def get_auth_token_user_pass(account, username, password, appid, ip=None, *, ses @transactional_session -def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"): +def get_auth_token_x509(account: InternalAccount, dn: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via an x509 certificate. @@ -182,7 +182,7 @@ def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"): @transactional_session -def get_auth_token_gss(account, gsstoken, appid, ip=None, *, session: "Session"): +def get_auth_token_gss(account: InternalAccount, gsstoken: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via a GSS token. @@ -286,7 +286,7 @@ def get_auth_token_ssh(account: InternalAccount, signature: bytes, appid: str, i @transactional_session -def get_ssh_challenge_token(account, appid, ip=None, *, session: "Session"): +def get_ssh_challenge_token(account: InternalAccount, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Prepare a challenge token for subsequent SSH public key authentication. @@ -322,7 +322,7 @@ def get_ssh_challenge_token(account, appid, ip=None, *, session: "Session"): @transactional_session -def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Session"): +def get_auth_token_saml(account: InternalAccount, saml_nameid: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via SAML. @@ -353,7 +353,7 @@ def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Sessi @transactional_session -def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"): +def redirect_auth_oidc(auth_code: str, fetchtoken: bool = False, *, session: "Session") -> Optional[str]: """ Finds the Authentication URL in the Rucio DB oauth_requests table and redirects user's browser to this URL. @@ -394,7 +394,7 @@ def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"): @transactional_session -def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session: "Session"): +def delete_expired_tokens(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session") -> int: """ Delete expired tokens. @@ -454,7 +454,7 @@ def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session: @read_session -def query_token(token, *, session: "Session"): +def query_token(token: str, *, session: "Session") -> Optional[TokenValidationDict]: """ Validate an authentication token using the database. This method will only be called if no entry could be found in the according cache. @@ -482,12 +482,12 @@ def query_token(token, *, session: "Session"): ) result = session.execute(query).first() if result: - return result._asdict() + return cast(TokenValidationDict, result._asdict()) return None @transactional_session -def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": +def validate_auth_token(token: str, *, session: "Session") -> TokenValidationDict: """ Validate an authentication token. @@ -509,7 +509,7 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": cache_key = token.replace(' ', '') # Check if token ca be found in cache region - value: "Union[NO_VALUE, dict[str, Any]]" = TOKENREGION.get(cache_key) + value: Union[NO_VALUE, dict[str, Any]] = TOKENREGION.get(cache_key) if value is NO_VALUE: # no cached entry found value = query_token(token, session=session) if not value: @@ -525,15 +525,11 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": if lifetime < datetime.datetime.utcnow(): # check if expired TOKENREGION.delete(cache_key) raise CannotAuthenticate(f"Token found but expired since {date_to_str(lifetime)}.") - return value - - -def token_dictionary(token: models.Token): - return {'token': token.token, 'expires_at': token.expired_at} + return cast(TokenValidationDict, value) @transactional_session -def __delete_expired_tokens_account(account, *, session: "Session"): +def __delete_expired_tokens_account(account: InternalAccount, *, session: "Session") -> None: """" Deletes expired tokens from the database. diff --git a/lib/rucio/core/oidc.py b/lib/rucio/core/oidc.py index 54c42a87b5..c6618325b9 100644 --- a/lib/rucio/core/oidc.py +++ b/lib/rucio/core/oidc.py @@ -20,7 +20,7 @@ import subprocess import traceback from datetime import datetime, timedelta -from typing import Any, Final, TYPE_CHECKING, Optional +from typing import Any, Final, TYPE_CHECKING, Optional, Union from urllib.parse import urljoin, urlparse, parse_qs import requests @@ -38,12 +38,12 @@ from sqlalchemy import delete, select, update from sqlalchemy.sql.expression import true -from rucio.common import types from rucio.common.cache import make_region_memcached from rucio.common.config import config_get, config_get_int from rucio.common.exception import (CannotAuthenticate, CannotAuthorize, RucioException) from rucio.common.stopwatch import Stopwatch +from rucio.common.types import InternalAccount, TokenDict, TokenValidationDict, TokenOIDCAutoDict, TokenOIDCNoAutoDict, TokenOIDCPollingDict from rucio.common.utils import all_oidc_req_claims_present, build_url, val_to_space_sep_str from rucio.core.account import account_exists from rucio.core.identity import exist_identity_account, get_default_account @@ -156,7 +156,7 @@ def request_token(audience: str, scope: str, use_cache: bool = True) -> Optional return token -def __get_rucio_oidc_clients(keytimeout: int = 43200) -> tuple[dict, dict]: +def __get_rucio_oidc_clients(keytimeout: int = 43200) -> tuple[dict[str, Client], dict[str, Client]]: """ Creates a Rucio OIDC Client instances per Identity Provider (IdP) according to etc/idpsecrets.json configuration file. @@ -469,7 +469,7 @@ def get_auth_oidc(account: str, *, session: "Session", **kwargs) -> str: @transactional_session -def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session"): +def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session") -> Optional[Union[TokenOIDCNoAutoDict, TokenOIDCAutoDict, TokenOIDCPollingDict]]: """ After Rucio User got redirected to Rucio /auth/oidc_token (or /auth/oidc_code) REST endpoints with authz code and session state encoded within the URL. @@ -479,8 +479,8 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" :param ip: IP address of the client as a string. :param session: The database session in use. - :returns: One of the following tuples: ("fetchcode", ); ("token", ); - ("polling", True); The result depends on the authentication strategy being used + :returns: One of the following dicts: {"fetchcode": }; {"token": }; + {"polling": True}; The result depends on the authentication strategy being used (no auto, auto, polling). """ try: @@ -530,7 +530,7 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" try: jwt_row_dict['account'] = get_default_account(jwt_row_dict['identity'], IdentityType.OIDC, True, session=session) except Exception: - return {'webhome': None, 'token': None} + return TokenOIDCAutoDict({'webhome': None, 'token': None}) # check if given account has the identity registered if not exist_identity_account(jwt_row_dict['identity'], IdentityType.OIDC, jwt_row_dict['account'], session=session): @@ -612,14 +612,14 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" session.commit() METRICS.timer('IdP_authorization').observe(stopwatch.elapsed) if '_polling' in oauth_req_params.access_msg: - return {'polling': True} + return TokenOIDCPollingDict({'polling': True}) elif 'http' in oauth_req_params.access_msg: - return {'webhome': oauth_req_params.access_msg, 'token': new_token} + return TokenOIDCAutoDict({'webhome': oauth_req_params.access_msg, 'token': new_token}) else: - return {'fetchcode': fetchcode} + return TokenOIDCNoAutoDict({'fetchcode': fetchcode}) else: METRICS.timer('IdP_authorization').observe(stopwatch.elapsed) - return {'token': new_token} + return TokenOIDCAutoDict({'token': new_token}) except Exception: # TO-DO catch different exceptions - InvalidGrant etc. ... @@ -630,13 +630,13 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" @transactional_session -def __get_admin_token_oidc(account: types.InternalAccount, req_scope, req_audience, issuer, *, session: "Session"): +def __get_admin_token_oidc(account: InternalAccount, req_scope: str, req_audience: str, issuer: str, *, session: "Session") -> Optional[TokenDict]: """ Get a token for Rucio application to act on behalf of itself. client_credential flow is used for this purpose. No refresh token is expected to be used. - :param account: the Rucio Admin account name to be used (InternalAccount object expected) + :param account: the Rucio Admin account name to be used :param req_scope: the audience requested for the Rucio client's token :param req_audience: the scope requested for the Rucio client's token :param issuer: the Identity Provider nickname or the Rucio instance in use @@ -689,9 +689,9 @@ def __get_admin_token_oidc(account: types.InternalAccount, req_scope, req_audien @read_session -def __get_admin_account_for_issuer(*, session: "Session"): +def __get_admin_account_for_issuer(*, session: "Session") -> dict[str, tuple[InternalAccount, str]]: """ Gets admin account for the IdP issuer - :returns : dictionary { 'issuer_1': (account, identity), ... } + :returns: dictionary { 'issuer_1': (account, identity), ... } """ if not OIDC_ADMIN_CLIENTS: @@ -715,7 +715,7 @@ def __get_admin_account_for_issuer(*, session: "Session"): @transactional_session -def get_token_for_account_operation(account: str, req_audience: str = None, req_scope: str = None, admin: bool = False, *, session: "Session"): +def get_token_for_account_operation(account: str, req_audience: str = None, req_scope: str = None, admin: bool = False, *, session: "Session") -> Optional[TokenDict]: """ Looks-up a JWT token with the required scope and audience claims with the account OIDC issuer. If tokens are found, and none contains the requested audience and scope a new token is requested @@ -1000,7 +1000,7 @@ def __change_refresh_state(token: str, refresh: bool = False, *, session: "Sessi @transactional_session -def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session"): +def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session") -> Optional[tuple[str, int]]: """ Checks if there is active refresh token and if so returns either active token with expiration timestamp or requests a new @@ -1079,7 +1079,7 @@ def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session @transactional_session -def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int = 3600, limit: int = 1000, *, session: "Session"): +def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int = 3600, limit: int = 1000, *, session: "Session") -> int: """ Refreshes tokens which expired or will expire before (now + refreshrate) next run of this function and which have valid refresh token. @@ -1089,7 +1089,7 @@ def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int :param limit: Maximum number of tokens to refresh per call. :param session: Database session in use. - :return: numper of tokens refreshed + :return: number of tokens refreshed """ nrefreshed = 0 try: @@ -1135,7 +1135,7 @@ def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int @METRICS.time_it @transactional_session -def __refresh_token_oidc(token_object: models.Token, *, session: "Session"): +def __refresh_token_oidc(token_object: models.Token, *, session: "Session") -> Optional[TokenDict]: """ Requests new access and refresh tokens from the Identity Provider. Assumption: The Identity Provider issues refresh tokens for one time use only and @@ -1212,7 +1212,7 @@ def __refresh_token_oidc(token_object: models.Token, *, session: "Session"): @transactional_session -def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session"): +def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session") -> int: """ Delete expired OAuth request parameters. @@ -1266,7 +1266,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): """ Extracting claims from token, e.g. scope and audience. :param token: the JWT to be unpacked - :param key: list of key names to extract from the token claims + :param keys: list of key names to extract from the token claims :returns: The list of unicode values under the key, throws an exception otherwise. """ @@ -1278,7 +1278,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): for key in keys: value = '' if key in claims: - value = val_to_space_sep_str(claims[key]) + value = val_to_space_sep_str(claims[key]) # type: ignore resdict[key] = value return resdict except Exception as error: @@ -1286,7 +1286,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): @read_session -def __get_rucio_jwt_dict(jwt: str, account=None, *, session: "Session"): +def __get_rucio_jwt_dict(jwt: str, account: Optional[InternalAccount] = None, *, session: "Session") -> Optional[TokenValidationDict]: """ Get a Rucio token dictionary from token claims. Check token expiration and find default Rucio @@ -1318,19 +1318,18 @@ def __get_rucio_jwt_dict(jwt: str, account=None, *, session: "Session"): if not exist_identity_account(identity_string, IdentityType.OIDC, account, session=session): logging.debug("No OIDC identity exists for account: %s", str(account)) return None - value = {'account': account, - 'identity': identity_string, - 'lifetime': expiry_date, - 'audience': audience, - 'authz_scope': scope} - return value + return TokenValidationDict({'account': account, + 'identity': identity_string, + 'lifetime': expiry_date, + 'audience': audience, + 'authz_scope': scope}) except Exception: logging.debug(traceback.format_exc()) return None @transactional_session -def __save_validated_token(token, valid_dict, extra_dict=None, *, session: "Session"): +def __save_validated_token(token: str, valid_dict: TokenValidationDict, extra_dict: Optional[dict] = None, *, session: "Session") -> TokenDict: """ Save JWT token to the Rucio DB. @@ -1405,7 +1404,7 @@ def validate_jwt(json_web_token: str, *, session: "Session") -> dict[str, Any]: clprocess = subprocess.Popen(['curl', '-s', '-L', '-u', '%s:%s' % (oidc_client.client_id, oidc_client.client_secret), '-d', 'token=%s' % (json_web_token), - oidc_client.introspection_endpoint], + oidc_client.introspection_endpoint], # type: ignore shell=False, stdout=subprocess.PIPE) inspect_claims = json.loads(clprocess.communicate()[0]) try: @@ -1433,7 +1432,7 @@ def validate_jwt(json_web_token: str, *, session: "Session") -> dict[str, Any]: raise CannotAuthenticate(traceback.format_exc()) -def oidc_identity_string(sub: str, iss: str): +def oidc_identity_string(sub: str, iss: str) -> str: """ Transform IdP sub claim and issuers url into users identity string. :param sub: users SUB claim from the Identity Provider @@ -1444,5 +1443,5 @@ def oidc_identity_string(sub: str, iss: str): return 'SUB=' + str(sub) + ', ISS=' + str(iss) -def token_dictionary(token: models.Token): +def token_dictionary(token: models.Token) -> TokenDict: return {'token': token.token, 'expires_at': token.expired_at}