Skip to content

Commit

Permalink
Refactor to use TokenDict dataclass and add type hints; rucio#6454
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Feb 13, 2024
1 parent 0997de5 commit 5232432
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 47 deletions.
7 changes: 7 additions & 0 deletions lib/rucio/common/types.py
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.

from typing import Any, Callable, Optional, TypedDict, Union
from datetime import datetime
from dataclasses import dataclass


class InternalType(object):
Expand Down Expand Up @@ -159,3 +161,8 @@ class RSESettingsDict(TypedDict):
deterministic: bool
domain: list[str]
protocols: list[RSEProtocolDict]

@dataclass
class TokenDict:
token: str
expires_at: datetime
40 changes: 18 additions & 22 deletions lib/rucio/core/authentication.py
Expand Up @@ -20,7 +20,7 @@
import sys
import traceback
from base64 import b64decode
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import paramiko
from dogpile.cache import make_region
Expand All @@ -30,6 +30,7 @@
from rucio.common.cache import make_region_memcached
from rucio.common.config import config_get_bool
from rucio.common.exception import CannotAuthenticate, RucioException
from rucio.common.types import InternalAccount, TokenDict
from rucio.common.utils import chunks, generate_uuid, date_to_str
from rucio.core.account import account_exists
from rucio.core.oidc import validate_jwt
Expand Down Expand Up @@ -90,7 +91,7 @@ def generate_key(token, *, session: "Session"):


@transactional_session
def get_auth_token_user_pass(account, username, password, appid, ip=None, *, session: "Session"):
def get_auth_token_user_pass(account: InternalAccount, username: str, password: str, appid: str, ip: str = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via username and password.
Expand Down Expand Up @@ -146,11 +147,10 @@ def get_auth_token_user_pass(account, username, password, appid, ip=None, *, ses
new_token = models.Token(account=db_account, identity=username, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)

return TokenDict(new_token.token, new_token.expired_at)

@transactional_session
def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"):
def get_auth_token_x509(account: InternalAccount, dn: str, appid: str, ip: str = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via an x509 certificate.
Expand Down Expand Up @@ -178,11 +178,11 @@ def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"):
new_token = models.Token(account=account, identity=dn, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)


@transactional_session
def get_auth_token_gss(account, gsstoken, appid, ip=None, *, session: "Session"):
def get_auth_token_gss(account: InternalAccount, gsstoken: str, appid: str, ip: str = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via a GSS token.
Expand Down Expand Up @@ -210,11 +210,11 @@ def get_auth_token_gss(account, gsstoken, appid, ip=None, *, session: "Session")
new_token = models.Token(account=account, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)


@transactional_session
def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session"):
def get_auth_token_ssh(account: InternalAccount, signature: str | bytes, appid: str, ip: str = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via SSH key exchange.
Expand Down Expand Up @@ -284,11 +284,11 @@ def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session"
new_token = models.Token(account=account, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)


@transactional_session
def get_ssh_challenge_token(account, appid, ip=None, *, session: "Session"):
def get_ssh_challenge_token(account: InternalAccount, appid, ip: str = None, *, session: "Session") -> Optional[TokenDict]:
"""
Prepare a challenge token for subsequent SSH public key authentication.
Expand Down Expand Up @@ -320,11 +320,11 @@ def get_ssh_challenge_token(account, appid, ip=None, *, session: "Session"):
expired_at=expiration)
new_challenge_token.save(session=session)

return token_dictionary(new_challenge_token)
return TokenDict(new_challenge_token.token, new_challenge_token.expired_at)


@transactional_session
def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Session"):
def get_auth_token_saml(account: InternalAccount, saml_nameid: str, appid: str, ip: str = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via SAML.
Expand All @@ -351,11 +351,11 @@ def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Sessi
new_token = models.Token(account=account, identity=saml_nameid, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)


@transactional_session
def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"):
def redirect_auth_oidc(auth_code: str, fetchtoken: bool = False, *, session: "Session") -> Optional[dict]:
"""
Finds the Authentication URL in the Rucio DB oauth_requests table
and redirects user's browser to this URL.
Expand Down Expand Up @@ -396,7 +396,7 @@ def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"):


@transactional_session
def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session: "Session"):
def delete_expired_tokens(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session") -> int:
"""
Delete expired tokens.
Expand Down Expand Up @@ -456,7 +456,7 @@ def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session:


@read_session
def query_token(token, *, session: "Session"):
def query_token(token: str, *, session: "Session") -> dict:
"""
Validate an authentication token using the database. This method will only be called
if no entry could be found in the according cache.
Expand Down Expand Up @@ -530,12 +530,8 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]":
return value


def token_dictionary(token: models.Token):
return {'token': token.token, 'expires_at': token.expired_at}


@transactional_session
def __delete_expired_tokens_account(account, *, session: "Session"):
def __delete_expired_tokens_account(account: InternalAccount, *, session: "Session"):
""""
Deletes expired tokens from the database.
Expand Down
46 changes: 21 additions & 25 deletions lib/rucio/core/oidc.py
Expand Up @@ -38,12 +38,12 @@
from sqlalchemy import delete, select, update
from sqlalchemy.sql.expression import true

from rucio.common import types
from rucio.common.cache import make_region_memcached
from rucio.common.config import config_get, config_get_int
from rucio.common.exception import (CannotAuthenticate, CannotAuthorize,
RucioException)
from rucio.common.stopwatch import Stopwatch
from rucio.common.types import InternalAccount, TokenDict
from rucio.common.utils import all_oidc_req_claims_present, build_url, val_to_space_sep_str
from rucio.core.account import account_exists
from rucio.core.identity import exist_identity_account, get_default_account
Expand Down Expand Up @@ -469,7 +469,7 @@ def get_auth_oidc(account: str, *, session: "Session", **kwargs) -> str:


@transactional_session
def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session"):
def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session") -> tuple:
"""
After Rucio User got redirected to Rucio /auth/oidc_token (or /auth/oidc_code)
REST endpoints with authz code and session state encoded within the URL.
Expand Down Expand Up @@ -630,13 +630,13 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session"


@transactional_session
def __get_admin_token_oidc(account: types.InternalAccount, req_scope, req_audience, issuer, *, session: "Session"):
def __get_admin_token_oidc(account: InternalAccount, req_scope: str, req_audience: str, issuer: str, *, session: "Session") -> Optional[TokenDict]:
"""
Get a token for Rucio application to act on behalf of itself.
client_credential flow is used for this purpose.
No refresh token is expected to be used.
:param account: the Rucio Admin account name to be used (InternalAccount object expected)
:param account: the Rucio Admin account name to be used
:param req_scope: the audience requested for the Rucio client's token
:param req_audience: the scope requested for the Rucio client's token
:param issuer: the Identity Provider nickname or the Rucio instance in use
Expand Down Expand Up @@ -689,9 +689,9 @@ def __get_admin_token_oidc(account: types.InternalAccount, req_scope, req_audien


@read_session
def __get_admin_account_for_issuer(*, session: "Session"):
def __get_admin_account_for_issuer(*, session: "Session") -> dict:
""" Gets admin account for the IdP issuer
:returns : dictionary { 'issuer_1': (account, identity), ... }
:returns: dictionary { 'issuer_1': (account, identity), ... }
"""

if not OIDC_ADMIN_CLIENTS:
Expand All @@ -715,7 +715,7 @@ def __get_admin_account_for_issuer(*, session: "Session"):


@transactional_session
def get_token_for_account_operation(account: str, req_audience: str = None, req_scope: str = None, admin: bool = False, *, session: "Session"):
def get_token_for_account_operation(account: str, req_audience: str = None, req_scope: str = None, admin: bool = False, *, session: "Session") -> Optional[TokenDict]:
"""
Looks-up a JWT token with the required scope and audience claims with the account OIDC issuer.
If tokens are found, and none contains the requested audience and scope a new token is requested
Expand Down Expand Up @@ -801,7 +801,7 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_
for admin_token in admin_account_tokens:
if hasattr(admin_token, 'audience') and hasattr(admin_token, 'oidc_scope') and\
all_oidc_req_claims_present(admin_token.oidc_scope, admin_token.audience, req_scope, req_audience):
return token_dictionary(admin_token)
return TokenDict(admin_token.token, admin_token.expired_at)
# if not found request a new one
new_admin_token = __get_admin_token_oidc(account, req_scope, req_audience, admin_issuer, session=session)
return new_admin_token
Expand Down Expand Up @@ -845,7 +845,7 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_
for admin_token in admin_account_tokens:
if hasattr(admin_token, 'audience') and hasattr(admin_token, 'oidc_scope') and\
all_oidc_req_claims_present(admin_token.oidc_scope, admin_token.audience, req_scope, req_audience):
return token_dictionary(admin_token)
return TokenDict(admin_token.token, admin_token.expired_at)
# if no admin token existing was found for the issuer of the valid user token
# we request a new one
new_admin_token = __get_admin_token_oidc(admin_account, req_scope, req_audience, admin_issuer, session=session)
Expand All @@ -868,7 +868,7 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_
for token in account_tokens:
if hasattr(token, 'audience') and hasattr(token, 'oidc_scope'):
if all_oidc_req_claims_present(token.oidc_scope, token.audience, req_scope, req_audience):
return token_dictionary(token)
return TokenDict(token.token, token.expired_at)
# from available tokens select preferentially the one which are being refreshed
if hasattr(token, 'oidc_scope') and ('offline_access' in str(token['oidc_scope'])):
subject_token = token
Expand Down Expand Up @@ -1000,7 +1000,7 @@ def __change_refresh_state(token: str, refresh: bool = False, *, session: "Sessi


@transactional_session
def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session"):
def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session") -> Optional[tuple]:
"""
Checks if there is active refresh token and if so returns
either active token with expiration timestamp or requests a new
Expand Down Expand Up @@ -1079,7 +1079,7 @@ def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session


@transactional_session
def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int = 3600, limit: int = 1000, *, session: "Session"):
def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int = 3600, limit: int = 1000, *, session: "Session") -> int:
"""
Refreshes tokens which expired or will expire before (now + refreshrate)
next run of this function and which have valid refresh token.
Expand All @@ -1089,7 +1089,7 @@ def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int
:param limit: Maximum number of tokens to refresh per call.
:param session: Database session in use.
:return: numper of tokens refreshed
:return: number of tokens refreshed
"""
nrefreshed = 0
try:
Expand Down Expand Up @@ -1135,7 +1135,7 @@ def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int

@METRICS.time_it
@transactional_session
def __refresh_token_oidc(token_object: models.Token, *, session: "Session"):
def __refresh_token_oidc(token_object: models.Token, *, session: "Session") -> Optional[TokenDict]:
"""
Requests new access and refresh tokens from the Identity Provider.
Assumption: The Identity Provider issues refresh tokens for one time use only and
Expand Down Expand Up @@ -1212,7 +1212,7 @@ def __refresh_token_oidc(token_object: models.Token, *, session: "Session"):


@transactional_session
def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session"):
def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session") -> int:
"""
Delete expired OAuth request parameters.
Expand Down Expand Up @@ -1262,7 +1262,7 @@ def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit:
raise RucioException(error.args) from error


def __get_keyvalues_from_claims(token: str, keys=None):
def __get_keyvalues_from_claims(token: str, keys: list[str] = None) -> list:
"""
Extracting claims from token, e.g. scope and audience.
:param token: the JWT to be unpacked
Expand All @@ -1286,7 +1286,7 @@ def __get_keyvalues_from_claims(token: str, keys=None):


@read_session
def __get_rucio_jwt_dict(jwt: str, account=None, *, session: "Session"):
def __get_rucio_jwt_dict(jwt: str, account: InternalAccount = None, *, session: "Session") -> Optional[dict]:
"""
Get a Rucio token dictionary from token claims.
Check token expiration and find default Rucio
Expand Down Expand Up @@ -1330,7 +1330,7 @@ def __get_rucio_jwt_dict(jwt: str, account=None, *, session: "Session"):


@transactional_session
def __save_validated_token(token, valid_dict, extra_dict=None, *, session: "Session"):
def __save_validated_token(token: str, valid_dict: dict, extra_dict: dict = None, *, session: "Session") -> TokenDict:
"""
Save JWT token to the Rucio DB.
Expand All @@ -1357,7 +1357,7 @@ def __save_validated_token(token, valid_dict, extra_dict=None, *, session: "Sess
ip=extra_dict.get('ip', None))
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)

except Exception as error:
raise RucioException(error.args) from error
Expand Down Expand Up @@ -1433,16 +1433,12 @@ def validate_jwt(json_web_token: str, *, session: "Session") -> dict[str, Any]:
raise CannotAuthenticate(traceback.format_exc())


def oidc_identity_string(sub: str, iss: str):
def oidc_identity_string(sub: str, iss: str) -> str:
"""
Transform IdP sub claim and issuers url into users identity string.
:param sub: users SUB claim from the Identity Provider
:param iss: issuer (IdP) https url
:returns: OIDC identity string "SUB=<usersid>, ISS=https://iam-test.ch/"
"""
return 'SUB=' + str(sub) + ', ISS=' + str(iss)


def token_dictionary(token: models.Token):
return {'token': token.token, 'expires_at': token.expired_at}
return 'SUB=' + str(sub) + ', ISS=' + str(iss)

0 comments on commit 5232432

Please sign in to comment.