diff --git a/lib/rucio/common/types.py b/lib/rucio/common/types.py index c3b0628dd8..849d89f6c0 100644 --- a/lib/rucio/common/types.py +++ b/lib/rucio/common/types.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from datetime import datetime from typing import Any, Callable, Optional, TypedDict, Union @@ -159,3 +160,8 @@ class RSESettingsDict(TypedDict): deterministic: bool domain: list[str] protocols: list[RSEProtocolDict] + + +class TokenDict(TypedDict): + token: str + expires_at: datetime diff --git a/lib/rucio/common/utils.py b/lib/rucio/common/utils.py index 34e7b8b013..a5d334cd33 100644 --- a/lib/rucio/common/utils.py +++ b/lib/rucio/common/utils.py @@ -40,7 +40,7 @@ from functools import partial, wraps from io import StringIO from itertools import zip_longest -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse, urlencode, quote, parse_qsl, urlunparse from uuid import uuid4 as uuid from xml.etree import ElementTree @@ -570,7 +570,7 @@ def str_to_date(string): return datetime.datetime.strptime(string, DATE_FORMAT) if string else None -def val_to_space_sep_str(vallist): +def val_to_space_sep_str(vallist: Any) -> str: """ Converts a list of values into a string of space separated values :param vallist: the list of values to to convert into string diff --git a/lib/rucio/core/account.py b/lib/rucio/core/account.py index 2c05ff6dc8..1015b5d016 100644 --- a/lib/rucio/core/account.py +++ b/lib/rucio/core/account.py @@ -17,9 +17,7 @@ from enum import Enum from re import match from traceback import format_exc -from typing import TYPE_CHECKING, Any -from collections.abc import Generator -import uuid +from typing import TYPE_CHECKING, Any, Iterator, Optional from sqlalchemy import select, and_ from sqlalchemy.exc import IntegrityError @@ -40,7 +38,7 @@ @transactional_session -def add_account(account: InternalAccount, type_: AccountType, email: str, *, session: "Session"): +def add_account(account: InternalAccount, type_: AccountType, email: str, *, session: "Session") -> None: """ Add an account with the given account name and type. :param account: the name of the new account. @@ -84,7 +82,7 @@ def account_exists(account: InternalAccount, *, session: "Session") -> bool: @read_session -def get_account(account: InternalAccount, *, session: "Session") -> dict: +def get_account(account: InternalAccount, *, session: "Session") -> models.Account: """ Returns an account for the given account name. :param account: the name of the account. @@ -105,7 +103,7 @@ def get_account(account: InternalAccount, *, session: "Session") -> dict: @transactional_session -def del_account(account: InternalAccount, *, session: "Session"): +def del_account(account: InternalAccount, *, session: "Session") -> None: """ Disable an account with the given account name. :param account: the account name. @@ -118,15 +116,15 @@ def del_account(account: InternalAccount, *, session: "Session"): models.Account.status == AccountStatus.ACTIVE ) try: - account = session.execute(query).scalar_one() + account_result = session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) - account.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()}) + account_result.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()}) @transactional_session -def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session"): +def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None: """ Update a property of an account. :param account: Name of the account. @@ -140,22 +138,22 @@ def update_account(account: InternalAccount, key: str, value: Any, *, session: " models.Account.account == account ) try: - account = session.execute(query).scalar_one() + account_result = session.execute(query).scalar_one() + if key == 'status': + if isinstance(value, str): + value = AccountStatus[value] + if value == AccountStatus.SUSPENDED: + account_result.update({'status': value, 'suspended_at': datetime.utcnow()}) + elif value == AccountStatus.ACTIVE: + account_result.update({'status': value, 'suspended_at': None}) + else: + account_result.update({key: value}) except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) - if key == 'status': - if isinstance(value, str): - value = AccountStatus[value] - if value == AccountStatus.SUSPENDED: - account.update({'status': value, 'suspended_at': datetime.utcnow()}) - elif value == AccountStatus.ACTIVE: - account.update({'status': value, 'suspended_at': None}) - else: - account.update({key: value}) @stream_session -def list_accounts(filter_: dict = None, *, session: "Session") -> Generator[dict]: +def list_accounts(filter_: Optional[dict[str, Any]] = None, *, session: "Session") -> Iterator[dict]: """ Returns a list of all account names. :param filter_: Dictionary of attributes by which the input data should be filtered @@ -216,7 +214,7 @@ def list_accounts(filter_: dict = None, *, session: "Session") -> Generator[dict @read_session -def list_identities(account: InternalAccount, *, session: "Session"): +def list_identities(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]: """ List all identities on an account. @@ -251,7 +249,7 @@ def list_identities(account: InternalAccount, *, session: "Session"): @read_session -def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[dict]: +def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[dict[str, Any]]: """ Get all attributes defined for an account. @@ -301,7 +299,7 @@ def has_account_attribute(account: InternalAccount, key: str, *, session: "Sessi @transactional_session -def add_account_attribute(account: InternalAccount, key: str, value: Any, *, session: "Session"): +def add_account_attribute(account: InternalAccount, key: str, value: Any, *, session: "Session") -> None: """ Add an attribute for the given account name. @@ -337,7 +335,7 @@ def add_account_attribute(account: InternalAccount, key: str, value: Any, *, ses @transactional_session -def del_account_attribute(account: InternalAccount, key: str, *, session: "Session"): +def del_account_attribute(account: InternalAccount, key: str, *, session: "Session") -> None: """ Add an attribute for the given account name. @@ -358,7 +356,7 @@ def del_account_attribute(account: InternalAccount, key: str, *, session: "Sessi @read_session -def get_usage(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> dict: +def get_usage(rse_id: str, account: InternalAccount, *, session: "Session") -> dict: """ Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist. @@ -403,14 +401,14 @@ def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Sessio @read_session -def get_usage_history(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> dict: +def get_usage_history(rse_id: str, account: InternalAccount, *, session: "Session") -> list[dict]: """ Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist. :param rse_id: The id of the RSE. :param account: The account name. :param session: The database session in use. - :returns: A dictionary {'bytes', 'files', 'updated_at'} + :returns: A list of dictionaries {'bytes', 'files', 'updated_at'} """ query = select( models.AccountUsageHistory.bytes, diff --git a/lib/rucio/core/account_counter.py b/lib/rucio/core/account_counter.py index f42a276d12..bc890e2285 100644 --- a/lib/rucio/core/account_counter.py +++ b/lib/rucio/core/account_counter.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import uuid import datetime from typing import TYPE_CHECKING @@ -25,13 +24,13 @@ from rucio.db.sqla.session import read_session, transactional_session if TYPE_CHECKING: - from sqlalchemy.orm import Session, Query + from sqlalchemy.orm import Session MAX_COUNTERS = 10 @transactional_session -def add_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session"): +def add_counter(rse_id: str, account: InternalAccount, *, session: "Session"): """ Creates the specified counter for a rse_id and account. @@ -44,7 +43,7 @@ def add_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Sessio @transactional_session -def increase(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: int, *, session: "Session"): +def increase(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session"): """ Increments the specified counter by the specified amount. @@ -58,7 +57,7 @@ def increase(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: in @transactional_session -def decrease(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None: +def decrease(rse_id: str, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None: """ Decreases the specified counter by the specified amount. @@ -72,7 +71,7 @@ def decrease(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: in @transactional_session -def del_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session"): +def del_counter(rse_id: str, account: InternalAccount, *, session: "Session") -> None: """ Resets the specified counter and initializes it by the specified amounts. @@ -85,7 +84,7 @@ def del_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Sessio @read_session -def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list["Query"]: +def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list[tuple[InternalAccount, str]]: """ Get updated rse_counters. @@ -108,7 +107,7 @@ def get_updated_account_counters(total_workers: int, worker_number: int, *, sess @transactional_session -def update_account_counter(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session"): +def update_account_counter(account: InternalAccount, rse_id: str, *, session: "Session") -> None: """ Read the updated_account_counters and update the account_counter. @@ -134,7 +133,7 @@ def update_account_counter(account: InternalAccount, rse_id: uuid.UUID, *, sessi @transactional_session -def update_account_counter_history(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session"): +def update_account_counter_history(account: InternalAccount, rse_id: str, *, session: "Session") -> None: """ Read the AccountUsage and update the AccountUsageHistory. diff --git a/lib/rucio/core/account_limit.py b/lib/rucio/core/account_limit.py index b1e4be592f..22343cc661 100644 --- a/lib/rucio/core/account_limit.py +++ b/lib/rucio/core/account_limit.py @@ -14,7 +14,7 @@ # limitations under the License. import uuid -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.sql import func, select, literal @@ -32,7 +32,7 @@ @read_session -def get_rse_account_usage(rse_id: uuid.UUID, *, session: "Session") -> list[dict]: +def get_rse_account_usage(rse_id: str, *, session: "Session") -> list[dict]: """ Returns the account limit and usage for all accounts on a RSE. @@ -92,7 +92,7 @@ def get_rse_account_usage(rse_id: uuid.UUID, *, session: "Session") -> list[dict @read_session -def get_global_account_limits(account: InternalAccount = None, *, session: "Session") -> dict: +def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict: """ Returns the global account limits for the account. @@ -123,7 +123,7 @@ def get_global_account_limits(account: InternalAccount = None, *, session: "Sess @read_session -def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> int: +def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> Union[int, float, None]: """ Returns the global account limit for the account on the rse expression. @@ -143,7 +143,7 @@ def get_global_account_limit(account: InternalAccount, rse_expression: str, *, s @read_session -def get_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session") -> Optional[int | float]: +def get_local_account_limit(account: InternalAccount, rse_id: str, *, session: "Session") -> Union[int, float, None]: """ Returns the account limit for the account on the rse. @@ -164,7 +164,7 @@ def get_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, sess @read_session -def get_local_account_limits(account: InternalAccount, rse_ids: list[uuid.UUID] = None, *, session: "Session") -> dict: +def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[str]] = None, *, session: "Session") -> dict: """ Returns the account limits for the account on the list of rses. @@ -198,7 +198,7 @@ def get_local_account_limits(account: InternalAccount, rse_ids: list[uuid.UUID] @transactional_session -def set_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, bytes_: int, *, session: "Session"): +def set_local_account_limit(account: InternalAccount, rse_id: str, bytes_: int, *, session: "Session"): """ Returns the limits for the account on the rse. @@ -234,7 +234,7 @@ def set_global_account_limit(account: InternalAccount, rse_expression: str, byte @transactional_session -def delete_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session") -> bool: +def delete_local_account_limit(account: InternalAccount, rse_id: str, *, session: "Session") -> bool: """ Deletes a local account limit. @@ -270,7 +270,7 @@ def delete_global_account_limit(account: InternalAccount, rse_expression: str, * @transactional_session -def get_local_account_usage(account: InternalAccount, rse_id: uuid.UUID = None, *, session: "Session") -> list[dict]: +def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[dict]: """ Read the account usage and connect it with (if available) the account limits of the account. @@ -313,7 +313,7 @@ def get_local_account_usage(account: InternalAccount, rse_id: uuid.UUID = None, @transactional_session -def get_global_account_usage(account: InternalAccount, rse_expression: str = None, *, session: "Session") -> list[dict]: +def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[dict]: """ Read the account usage and connect it with the global account limits of the account. diff --git a/lib/rucio/core/authentication.py b/lib/rucio/core/authentication.py index 6dca804865..44d0130a8e 100644 --- a/lib/rucio/core/authentication.py +++ b/lib/rucio/core/authentication.py @@ -20,7 +20,7 @@ import sys import traceback from base64 import b64decode -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union import paramiko from dogpile.cache import make_region @@ -30,6 +30,7 @@ from rucio.common.cache import make_region_memcached from rucio.common.config import config_get_bool from rucio.common.exception import CannotAuthenticate, RucioException +from rucio.common.types import InternalAccount, TokenDict from rucio.common.utils import chunks, generate_uuid, date_to_str from rucio.core.account import account_exists from rucio.core.oidc import validate_jwt @@ -40,7 +41,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from typing import Any, Union + from typing import Any def strip_x509_proxy_attributes(dn: str) -> str: @@ -90,7 +91,7 @@ def generate_key(token, *, session: "Session"): @transactional_session -def get_auth_token_user_pass(account, username, password, appid, ip=None, *, session: "Session"): +def get_auth_token_user_pass(account: InternalAccount, username: str, password: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via username and password. @@ -150,7 +151,7 @@ def get_auth_token_user_pass(account, username, password, appid, ip=None, *, ses @transactional_session -def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"): +def get_auth_token_x509(account: InternalAccount, dn: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via an x509 certificate. @@ -182,7 +183,7 @@ def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"): @transactional_session -def get_auth_token_gss(account, gsstoken, appid, ip=None, *, session: "Session"): +def get_auth_token_gss(account: InternalAccount, gsstoken: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via a GSS token. @@ -214,7 +215,7 @@ def get_auth_token_gss(account, gsstoken, appid, ip=None, *, session: "Session") @transactional_session -def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session"): +def get_auth_token_ssh(account: InternalAccount, signature: Union[str, bytes], appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via SSH key exchange. @@ -228,7 +229,7 @@ def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session" :returns: A dict with token and expires_at entries. """ - if not isinstance(signature, bytes): + if isinstance(signature, str): signature = signature.encode() # Make sure the account exists @@ -306,7 +307,7 @@ def get_ssh_challenge_token(account: InternalAccount, appid: str, ip: Optional[s return None # Cryptographically secure random number. - # This requires a /dev/urandom like device from the OS + # This requires a /dev/urandom like device from the OS rng = random.SystemRandom() crypto_rand = rng.randint(0, sys.maxsize) @@ -324,7 +325,7 @@ def get_ssh_challenge_token(account: InternalAccount, appid: str, ip: Optional[s @transactional_session -def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Session"): +def get_auth_token_saml(account: InternalAccount, saml_nameid: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via SAML. @@ -355,7 +356,7 @@ def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Sessi @transactional_session -def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"): +def redirect_auth_oidc(auth_code: str, fetchtoken: bool = False, *, session: "Session") -> Optional[str]: """ Finds the Authentication URL in the Rucio DB oauth_requests table and redirects user's browser to this URL. @@ -396,7 +397,7 @@ def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"): @transactional_session -def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session: "Session"): +def delete_expired_tokens(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session") -> int: """ Delete expired tokens. @@ -456,7 +457,7 @@ def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session: @read_session -def query_token(token, *, session: "Session"): +def query_token(token: str, *, session: "Session") -> Optional[dict]: """ Validate an authentication token using the database. This method will only be called if no entry could be found in the according cache. @@ -530,12 +531,12 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": return value -def token_dictionary(token: models.Token): +def token_dictionary(token: models.Token) -> TokenDict: return {'token': token.token, 'expires_at': token.expired_at} @transactional_session -def __delete_expired_tokens_account(account, *, session: "Session"): +def __delete_expired_tokens_account(account: InternalAccount, *, session: "Session") -> None: """" Deletes expired tokens from the database. diff --git a/lib/rucio/core/oidc.py b/lib/rucio/core/oidc.py index 54c42a87b5..0eea58c1cb 100644 --- a/lib/rucio/core/oidc.py +++ b/lib/rucio/core/oidc.py @@ -38,12 +38,12 @@ from sqlalchemy import delete, select, update from sqlalchemy.sql.expression import true -from rucio.common import types from rucio.common.cache import make_region_memcached from rucio.common.config import config_get, config_get_int from rucio.common.exception import (CannotAuthenticate, CannotAuthorize, RucioException) from rucio.common.stopwatch import Stopwatch +from rucio.common.types import InternalAccount, TokenDict from rucio.common.utils import all_oidc_req_claims_present, build_url, val_to_space_sep_str from rucio.core.account import account_exists from rucio.core.identity import exist_identity_account, get_default_account @@ -469,7 +469,7 @@ def get_auth_oidc(account: str, *, session: "Session", **kwargs) -> str: @transactional_session -def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session"): +def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session") -> Optional[dict]: """ After Rucio User got redirected to Rucio /auth/oidc_token (or /auth/oidc_code) REST endpoints with authz code and session state encoded within the URL. @@ -479,8 +479,8 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" :param ip: IP address of the client as a string. :param session: The database session in use. - :returns: One of the following tuples: ("fetchcode", ); ("token", ); - ("polling", True); The result depends on the authentication strategy being used + :returns: One of the following dicts: {"fetchcode": }; {"token": }; + {"polling": True}; The result depends on the authentication strategy being used (no auto, auto, polling). """ try: @@ -630,13 +630,13 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" @transactional_session -def __get_admin_token_oidc(account: types.InternalAccount, req_scope, req_audience, issuer, *, session: "Session"): +def __get_admin_token_oidc(account: InternalAccount, req_scope: str, req_audience: str, issuer: str, *, session: "Session") -> Optional[TokenDict]: """ Get a token for Rucio application to act on behalf of itself. client_credential flow is used for this purpose. No refresh token is expected to be used. - :param account: the Rucio Admin account name to be used (InternalAccount object expected) + :param account: the Rucio Admin account name to be used :param req_scope: the audience requested for the Rucio client's token :param req_audience: the scope requested for the Rucio client's token :param issuer: the Identity Provider nickname or the Rucio instance in use @@ -689,9 +689,9 @@ def __get_admin_token_oidc(account: types.InternalAccount, req_scope, req_audien @read_session -def __get_admin_account_for_issuer(*, session: "Session"): +def __get_admin_account_for_issuer(*, session: "Session") -> dict: """ Gets admin account for the IdP issuer - :returns : dictionary { 'issuer_1': (account, identity), ... } + :returns: dictionary { 'issuer_1': (account, identity), ... } """ if not OIDC_ADMIN_CLIENTS: @@ -715,7 +715,7 @@ def __get_admin_account_for_issuer(*, session: "Session"): @transactional_session -def get_token_for_account_operation(account: str, req_audience: str = None, req_scope: str = None, admin: bool = False, *, session: "Session"): +def get_token_for_account_operation(account: str, req_audience: str = None, req_scope: str = None, admin: bool = False, *, session: "Session") -> Optional[TokenDict]: """ Looks-up a JWT token with the required scope and audience claims with the account OIDC issuer. If tokens are found, and none contains the requested audience and scope a new token is requested @@ -801,7 +801,7 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_ for admin_token in admin_account_tokens: if hasattr(admin_token, 'audience') and hasattr(admin_token, 'oidc_scope') and\ all_oidc_req_claims_present(admin_token.oidc_scope, admin_token.audience, req_scope, req_audience): - return token_dictionary(admin_token) + return {'token': admin_token.token, 'expires_at': admin_token.expired_at} # if not found request a new one new_admin_token = __get_admin_token_oidc(account, req_scope, req_audience, admin_issuer, session=session) return new_admin_token @@ -845,7 +845,7 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_ for admin_token in admin_account_tokens: if hasattr(admin_token, 'audience') and hasattr(admin_token, 'oidc_scope') and\ all_oidc_req_claims_present(admin_token.oidc_scope, admin_token.audience, req_scope, req_audience): - return token_dictionary(admin_token) + return {'token': admin_token.token, 'expires_at': admin_token.expired_at} # if no admin token existing was found for the issuer of the valid user token # we request a new one new_admin_token = __get_admin_token_oidc(admin_account, req_scope, req_audience, admin_issuer, session=session) @@ -868,7 +868,7 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_ for token in account_tokens: if hasattr(token, 'audience') and hasattr(token, 'oidc_scope'): if all_oidc_req_claims_present(token.oidc_scope, token.audience, req_scope, req_audience): - return token_dictionary(token) + return {'token': token.token, 'expires_at': token.expired_at} # from available tokens select preferentially the one which are being refreshed if hasattr(token, 'oidc_scope') and ('offline_access' in str(token['oidc_scope'])): subject_token = token @@ -1000,7 +1000,7 @@ def __change_refresh_state(token: str, refresh: bool = False, *, session: "Sessi @transactional_session -def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session"): +def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session") -> Optional[tuple[str, int]]: """ Checks if there is active refresh token and if so returns either active token with expiration timestamp or requests a new @@ -1079,7 +1079,7 @@ def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session @transactional_session -def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int = 3600, limit: int = 1000, *, session: "Session"): +def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int = 3600, limit: int = 1000, *, session: "Session") -> int: """ Refreshes tokens which expired or will expire before (now + refreshrate) next run of this function and which have valid refresh token. @@ -1089,7 +1089,7 @@ def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int :param limit: Maximum number of tokens to refresh per call. :param session: Database session in use. - :return: numper of tokens refreshed + :return: number of tokens refreshed """ nrefreshed = 0 try: @@ -1135,7 +1135,7 @@ def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int @METRICS.time_it @transactional_session -def __refresh_token_oidc(token_object: models.Token, *, session: "Session"): +def __refresh_token_oidc(token_object: models.Token, *, session: "Session") -> Optional[TokenDict]: """ Requests new access and refresh tokens from the Identity Provider. Assumption: The Identity Provider issues refresh tokens for one time use only and @@ -1212,7 +1212,7 @@ def __refresh_token_oidc(token_object: models.Token, *, session: "Session"): @transactional_session -def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session"): +def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session") -> int: """ Delete expired OAuth request parameters. @@ -1266,7 +1266,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): """ Extracting claims from token, e.g. scope and audience. :param token: the JWT to be unpacked - :param key: list of key names to extract from the token claims + :param keys: list of key names to extract from the token claims :returns: The list of unicode values under the key, throws an exception otherwise. """ @@ -1278,7 +1278,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): for key in keys: value = '' if key in claims: - value = val_to_space_sep_str(claims[key]) + value = val_to_space_sep_str(claims[key]) # type: ignore resdict[key] = value return resdict except Exception as error: @@ -1286,7 +1286,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): @read_session -def __get_rucio_jwt_dict(jwt: str, account=None, *, session: "Session"): +def __get_rucio_jwt_dict(jwt: str, account: Optional[InternalAccount] = None, *, session: "Session") -> Optional[dict]: """ Get a Rucio token dictionary from token claims. Check token expiration and find default Rucio @@ -1330,7 +1330,7 @@ def __get_rucio_jwt_dict(jwt: str, account=None, *, session: "Session"): @transactional_session -def __save_validated_token(token, valid_dict, extra_dict=None, *, session: "Session"): +def __save_validated_token(token: str, valid_dict: dict, extra_dict: Optional[dict] = None, *, session: "Session") -> TokenDict: """ Save JWT token to the Rucio DB. @@ -1357,7 +1357,7 @@ def __save_validated_token(token, valid_dict, extra_dict=None, *, session: "Sess ip=extra_dict.get('ip', None)) new_token.save(session=session) - return token_dictionary(new_token) + return {'token': new_token.token, 'expires_at': new_token.expired_at} except Exception as error: raise RucioException(error.args) from error @@ -1433,7 +1433,7 @@ def validate_jwt(json_web_token: str, *, session: "Session") -> dict[str, Any]: raise CannotAuthenticate(traceback.format_exc()) -def oidc_identity_string(sub: str, iss: str): +def oidc_identity_string(sub: str, iss: str) -> str: """ Transform IdP sub claim and issuers url into users identity string. :param sub: users SUB claim from the Identity Provider @@ -1442,7 +1442,3 @@ def oidc_identity_string(sub: str, iss: str): :returns: OIDC identity string "SUB=, ISS=https://iam-test.ch/" """ return 'SUB=' + str(sub) + ', ISS=' + str(iss) - - -def token_dictionary(token: models.Token): - return {'token': token.token, 'expires_at': token.expired_at}