Skip to content

Commit

Permalink
Refactor to use TokenDict dataclass and add type hints; rucio#6454
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Feb 13, 2024
1 parent 0997de5 commit 811b5ef
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 76 deletions.
8 changes: 8 additions & 0 deletions lib/rucio/common/types.py
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.

from typing import Any, Callable, Optional, TypedDict, Union
from datetime import datetime
from dataclasses import dataclass


class InternalType(object):
Expand Down Expand Up @@ -159,3 +161,9 @@ class RSESettingsDict(TypedDict):
deterministic: bool
domain: list[str]
protocols: list[RSEProtocolDict]


@dataclass
class TokenDict:
token: str
expires_at: datetime
36 changes: 17 additions & 19 deletions lib/rucio/core/account.py
Expand Up @@ -17,8 +17,7 @@
from enum import Enum
from re import match
from traceback import format_exc
from typing import TYPE_CHECKING, Any
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Iterator, Optional
import uuid

from sqlalchemy import select, and_
Expand Down Expand Up @@ -84,7 +83,7 @@ def account_exists(account: InternalAccount, *, session: "Session") -> bool:


@read_session
def get_account(account: InternalAccount, *, session: "Session") -> dict:
def get_account(account: InternalAccount, *, session: "Session") -> models.Account:
""" Returns an account for the given account name.
:param account: the name of the account.
Expand Down Expand Up @@ -118,12 +117,11 @@ def del_account(account: InternalAccount, *, session: "Session"):
models.Account.status == AccountStatus.ACTIVE
)
try:
account = session.execute(query).scalar_one()
account_result = session.execute(query).scalar_one()
account_result.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()})
except exc.NoResultFound:
raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account)

account.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()})


@transactional_session
def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session"):
Expand All @@ -140,22 +138,22 @@ def update_account(account: InternalAccount, key: str, value: Any, *, session: "
models.Account.account == account
)
try:
account = session.execute(query).scalar_one()
account_result = session.execute(query).scalar_one()
if key == 'status':
if isinstance(value, str):
value = AccountStatus[value]
if value == AccountStatus.SUSPENDED:
account_result.update({'status': value, 'suspended_at': datetime.utcnow()})
elif value == AccountStatus.ACTIVE:
account_result.update({'status': value, 'suspended_at': None})
else:
account_result.update({key: value})
except exc.NoResultFound:
raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account)
if key == 'status':
if isinstance(value, str):
value = AccountStatus[value]
if value == AccountStatus.SUSPENDED:
account.update({'status': value, 'suspended_at': datetime.utcnow()})
elif value == AccountStatus.ACTIVE:
account.update({'status': value, 'suspended_at': None})
else:
account.update({key: value})


@stream_session
def list_accounts(filter_: dict = None, *, session: "Session") -> Generator[dict]:
def list_accounts(filter_: Optional[dict] = None, *, session: "Session") -> Iterator[dict]:
""" Returns a list of all account names.
:param filter_: Dictionary of attributes by which the input data should be filtered
Expand Down Expand Up @@ -403,14 +401,14 @@ def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Sessio


@read_session
def get_usage_history(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> dict:
def get_usage_history(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> list[dict]:
"""
Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist.
:param rse_id: The id of the RSE.
:param account: The account name.
:param session: The database session in use.
:returns: A dictionary {'bytes', 'files', 'updated_at'}
:returns: A list of dictionaries {'bytes', 'files', 'updated_at'}
"""
query = select(
models.AccountUsageHistory.bytes,
Expand Down
10 changes: 5 additions & 5 deletions lib/rucio/core/account_limit.py
Expand Up @@ -92,7 +92,7 @@ def get_rse_account_usage(rse_id: uuid.UUID, *, session: "Session") -> list[dict


@read_session
def get_global_account_limits(account: InternalAccount = None, *, session: "Session") -> dict:
def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict:
"""
Returns the global account limits for the account.
Expand Down Expand Up @@ -123,7 +123,7 @@ def get_global_account_limits(account: InternalAccount = None, *, session: "Sess


@read_session
def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> int:
def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> Optional[int | float]:
"""
Returns the global account limit for the account on the rse expression.
Expand Down Expand Up @@ -164,7 +164,7 @@ def get_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, sess


@read_session
def get_local_account_limits(account: InternalAccount, rse_ids: list[uuid.UUID] = None, *, session: "Session") -> dict:
def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[uuid.UUID]] = None, *, session: "Session") -> dict:
"""
Returns the account limits for the account on the list of rses.
Expand Down Expand Up @@ -270,7 +270,7 @@ def delete_global_account_limit(account: InternalAccount, rse_expression: str, *


@transactional_session
def get_local_account_usage(account: InternalAccount, rse_id: uuid.UUID = None, *, session: "Session") -> list[dict]:
def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[dict]:
"""
Read the account usage and connect it with (if available) the account limits of the account.
Expand Down Expand Up @@ -313,7 +313,7 @@ def get_local_account_usage(account: InternalAccount, rse_id: uuid.UUID = None,


@transactional_session
def get_global_account_usage(account: InternalAccount, rse_expression: str = None, *, session: "Session") -> list[dict]:
def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[dict]:
"""
Read the account usage and connect it with the global account limits of the account.
Expand Down
41 changes: 19 additions & 22 deletions lib/rucio/core/authentication.py
Expand Up @@ -20,7 +20,7 @@
import sys
import traceback
from base64 import b64decode
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import paramiko
from dogpile.cache import make_region
Expand All @@ -30,6 +30,7 @@
from rucio.common.cache import make_region_memcached
from rucio.common.config import config_get_bool
from rucio.common.exception import CannotAuthenticate, RucioException
from rucio.common.types import InternalAccount, TokenDict
from rucio.common.utils import chunks, generate_uuid, date_to_str
from rucio.core.account import account_exists
from rucio.core.oidc import validate_jwt
Expand Down Expand Up @@ -90,7 +91,7 @@ def generate_key(token, *, session: "Session"):


@transactional_session
def get_auth_token_user_pass(account, username, password, appid, ip=None, *, session: "Session"):
def get_auth_token_user_pass(account: InternalAccount, username: str, password: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via username and password.
Expand Down Expand Up @@ -146,11 +147,11 @@ def get_auth_token_user_pass(account, username, password, appid, ip=None, *, ses
new_token = models.Token(account=db_account, identity=username, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)


@transactional_session
def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"):
def get_auth_token_x509(account: InternalAccount, dn: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via an x509 certificate.
Expand Down Expand Up @@ -178,11 +179,11 @@ def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"):
new_token = models.Token(account=account, identity=dn, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)


@transactional_session
def get_auth_token_gss(account, gsstoken, appid, ip=None, *, session: "Session"):
def get_auth_token_gss(account: InternalAccount, gsstoken: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via a GSS token.
Expand Down Expand Up @@ -210,11 +211,11 @@ def get_auth_token_gss(account, gsstoken, appid, ip=None, *, session: "Session")
new_token = models.Token(account=account, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)


@transactional_session
def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session"):
def get_auth_token_ssh(account: InternalAccount, signature: str | bytes, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via SSH key exchange.
Expand All @@ -228,7 +229,7 @@ def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session"
:returns: A dict with token and expires_at entries.
"""
if not isinstance(signature, bytes):
if isinstance(signature, str):
signature = signature.encode()

# Make sure the account exists
Expand Down Expand Up @@ -284,11 +285,11 @@ def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session"
new_token = models.Token(account=account, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)


@transactional_session
def get_ssh_challenge_token(account, appid, ip=None, *, session: "Session"):
def get_ssh_challenge_token(account: InternalAccount, appid, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]:
"""
Prepare a challenge token for subsequent SSH public key authentication.
Expand Down Expand Up @@ -320,11 +321,11 @@ def get_ssh_challenge_token(account, appid, ip=None, *, session: "Session"):
expired_at=expiration)
new_challenge_token.save(session=session)

return token_dictionary(new_challenge_token)
return TokenDict(new_challenge_token.token, new_challenge_token.expired_at)


@transactional_session
def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Session"):
def get_auth_token_saml(account: InternalAccount, saml_nameid: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]:
"""
Authenticate a Rucio account temporarily via SAML.
Expand All @@ -351,11 +352,11 @@ def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Sessi
new_token = models.Token(account=account, identity=saml_nameid, token=token, ip=ip)
new_token.save(session=session)

return token_dictionary(new_token)
return TokenDict(new_token.token, new_token.expired_at)


@transactional_session
def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"):
def redirect_auth_oidc(auth_code: str, fetchtoken: bool = False, *, session: "Session") -> Optional[str]:
"""
Finds the Authentication URL in the Rucio DB oauth_requests table
and redirects user's browser to this URL.
Expand Down Expand Up @@ -396,7 +397,7 @@ def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"):


@transactional_session
def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session: "Session"):
def delete_expired_tokens(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session") -> int:
"""
Delete expired tokens.
Expand Down Expand Up @@ -456,7 +457,7 @@ def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session:


@read_session
def query_token(token, *, session: "Session"):
def query_token(token: str, *, session: "Session") -> Optional[dict]:
"""
Validate an authentication token using the database. This method will only be called
if no entry could be found in the according cache.
Expand Down Expand Up @@ -530,12 +531,8 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]":
return value


def token_dictionary(token: models.Token):
return {'token': token.token, 'expires_at': token.expired_at}


@transactional_session
def __delete_expired_tokens_account(account, *, session: "Session"):
def __delete_expired_tokens_account(account: InternalAccount, *, session: "Session"):
""""
Deletes expired tokens from the database.
Expand Down

0 comments on commit 811b5ef

Please sign in to comment.