Skip to content

Commit

Permalink
Add type hints; rucio#6454
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Feb 13, 2024
1 parent 474ecdc commit 0997de5
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 36 deletions.
34 changes: 18 additions & 16 deletions lib/rucio/core/account.py
Expand Up @@ -17,7 +17,9 @@
from enum import Enum
from re import match
from traceback import format_exc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from collections.abc import Generator
import uuid

from sqlalchemy import select, and_
from sqlalchemy.exc import IntegrityError
Expand All @@ -27,6 +29,7 @@
import rucio.core.rse
from rucio.common import exception
from rucio.common.config import config_get_bool
from rucio.common.types import InternalAccount
from rucio.core.vo import vo_exists
from rucio.db.sqla import models
from rucio.db.sqla.constants import AccountStatus, AccountType
Expand All @@ -37,7 +40,7 @@


@transactional_session
def add_account(account, type_, email, *, session: "Session"):
def add_account(account: InternalAccount, type_: AccountType, email: str, *, session: "Session"):
""" Add an account with the given account name and type.
:param account: the name of the new account.
Expand All @@ -63,7 +66,7 @@ def add_account(account, type_, email, *, session: "Session"):


@read_session
def account_exists(account, *, session: "Session"):
def account_exists(account: InternalAccount, *, session: "Session") -> bool:
""" Checks to see if account exists and is active.
:param account: Name of the account.
Expand All @@ -81,7 +84,7 @@ def account_exists(account, *, session: "Session"):


@read_session
def get_account(account, *, session: "Session"):
def get_account(account: InternalAccount, *, session: "Session") -> dict:
""" Returns an account for the given account name.
:param account: the name of the account.
Expand All @@ -102,7 +105,7 @@ def get_account(account, *, session: "Session"):


@transactional_session
def del_account(account, *, session: "Session"):
def del_account(account: InternalAccount, *, session: "Session"):
""" Disable an account with the given account name.
:param account: the account name.
Expand All @@ -123,7 +126,7 @@ def del_account(account, *, session: "Session"):


@transactional_session
def update_account(account, key, value, *, session: "Session"):
def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session"):
""" Update a property of an account.
:param account: Name of the account.
Expand Down Expand Up @@ -152,7 +155,7 @@ def update_account(account, key, value, *, session: "Session"):


@stream_session
def list_accounts(filter_=None, *, session: "Session"):
def list_accounts(filter_: dict = None, *, session: "Session") -> Generator[dict]:
""" Returns a list of all account names.
:param filter_: Dictionary of attributes by which the input data should be filtered
Expand Down Expand Up @@ -213,7 +216,7 @@ def list_accounts(filter_=None, *, session: "Session"):


@read_session
def list_identities(account, *, session: "Session"):
def list_identities(account: InternalAccount, *, session: "Session"):
"""
List all identities on an account.
Expand Down Expand Up @@ -248,7 +251,7 @@ def list_identities(account, *, session: "Session"):


@read_session
def list_account_attributes(account, *, session: "Session"):
def list_account_attributes(account: InternalAccount, *, session: "Session") -> list[dict]:
"""
Get all attributes defined for an account.
Expand Down Expand Up @@ -278,7 +281,7 @@ def list_account_attributes(account, *, session: "Session"):


@read_session
def has_account_attribute(account, key, *, session: "Session"):
def has_account_attribute(account: InternalAccount, key: str, *, session: "Session") -> bool:
"""
Indicates whether the named key is present for the account.
Expand All @@ -298,7 +301,7 @@ def has_account_attribute(account, key, *, session: "Session"):


@transactional_session
def add_account_attribute(account, key, value, *, session: "Session"):
def add_account_attribute(account: InternalAccount, key: str, value: Any, *, session: "Session"):
"""
Add an attribute for the given account name.
Expand Down Expand Up @@ -334,7 +337,7 @@ def add_account_attribute(account, key, value, *, session: "Session"):


@transactional_session
def del_account_attribute(account, key, *, session: "Session"):
def del_account_attribute(account: InternalAccount, key: str, *, session: "Session"):
"""
Add an attribute for the given account name.
Expand All @@ -355,7 +358,7 @@ def del_account_attribute(account, key, *, session: "Session"):


@read_session
def get_usage(rse_id, account, *, session: "Session"):
def get_usage(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> dict:
"""
Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist.
Expand All @@ -379,7 +382,7 @@ def get_usage(rse_id, account, *, session: "Session"):


@read_session
def get_all_rse_usages_per_account(account, *, session: "Session"):
def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Session") -> list[dict]:
"""
Returns current values of the specified counter, or raises CounterNotFound if the counter does not exist.
Expand All @@ -400,7 +403,7 @@ def get_all_rse_usages_per_account(account, *, session: "Session"):


@read_session
def get_usage_history(rse_id, account, *, session: "Session"):
def get_usage_history(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> dict:
"""
Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist.
Expand All @@ -423,4 +426,3 @@ def get_usage_history(rse_id, account, *, session: "Session"):
return [row._asdict() for row in session.execute(query)]
except exc.NoResultFound:
raise exception.CounterNotFound('No usage can be found for account %s on RSE %s' % (account, rucio.core.rse.get_rse_name(rse_id=rse_id, session=session)))
return []
19 changes: 11 additions & 8 deletions lib/rucio/core/account_counter.py
Expand Up @@ -12,23 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
import datetime
from typing import TYPE_CHECKING

from sqlalchemy import literal, insert, select
from sqlalchemy.orm.exc import NoResultFound

from rucio.common.types import InternalAccount
from rucio.db.sqla import models, filter_thread_work
from rucio.db.sqla.session import read_session, transactional_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, Query

MAX_COUNTERS = 10


@transactional_session
def add_counter(rse_id, account, *, session: "Session"):
def add_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session"):
"""
Creates the specified counter for a rse_id and account.
Expand All @@ -41,7 +44,7 @@ def add_counter(rse_id, account, *, session: "Session"):


@transactional_session
def increase(rse_id, account, files, bytes_, *, session: "Session"):
def increase(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: int, *, session: "Session"):
"""
Increments the specified counter by the specified amount.
Expand All @@ -55,7 +58,7 @@ def increase(rse_id, account, files, bytes_, *, session: "Session"):


@transactional_session
def decrease(rse_id, account, files, bytes_, *, session: "Session"):
def decrease(rse_id: uuid.UUID, account: InternalAccount, files: int, bytes_: int, *, session: "Session") -> None:
"""
Decreases the specified counter by the specified amount.
Expand All @@ -69,7 +72,7 @@ def decrease(rse_id, account, files, bytes_, *, session: "Session"):


@transactional_session
def del_counter(rse_id, account, *, session: "Session"):
def del_counter(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session"):
"""
Resets the specified counter and initializes it by the specified amounts.
Expand All @@ -82,7 +85,7 @@ def del_counter(rse_id, account, *, session: "Session"):


@read_session
def get_updated_account_counters(total_workers, worker_number, *, session: "Session"):
def get_updated_account_counters(total_workers: int, worker_number: int, *, session: "Session") -> list["Query"]:
"""
Get updated rse_counters.
Expand All @@ -105,7 +108,7 @@ def get_updated_account_counters(total_workers, worker_number, *, session: "Sess


@transactional_session
def update_account_counter(account, rse_id, *, session: "Session"):
def update_account_counter(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session"):
"""
Read the updated_account_counters and update the account_counter.
Expand All @@ -131,7 +134,7 @@ def update_account_counter(account, rse_id, *, session: "Session"):


@transactional_session
def update_account_counter_history(account, rse_id, *, session: "Session"):
def update_account_counter_history(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session"):
"""
Read the AccountUsage and update the AccountUsageHistory.
Expand Down
26 changes: 14 additions & 12 deletions lib/rucio/core/account_limit.py
Expand Up @@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
import uuid
from typing import TYPE_CHECKING, Optional

from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql import func, select, literal
from sqlalchemy.sql.expression import and_, or_

from rucio.common.types import InternalAccount
from rucio.core.account import get_all_rse_usages_per_account
from rucio.core.rse import get_rse_name
from rucio.core.rse_expression_parser import parse_expression
Expand All @@ -30,7 +32,7 @@


@read_session
def get_rse_account_usage(rse_id, *, session: "Session"):
def get_rse_account_usage(rse_id: uuid.UUID, *, session: "Session") -> list[dict]:
"""
Returns the account limit and usage for all accounts on a RSE.
Expand Down Expand Up @@ -90,7 +92,7 @@ def get_rse_account_usage(rse_id, *, session: "Session"):


@read_session
def get_global_account_limits(account=None, *, session: "Session"):
def get_global_account_limits(account: InternalAccount = None, *, session: "Session") -> dict:
"""
Returns the global account limits for the account.
Expand Down Expand Up @@ -121,7 +123,7 @@ def get_global_account_limits(account=None, *, session: "Session"):


@read_session
def get_global_account_limit(account, rse_expression, *, session: "Session"):
def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> int:
"""
Returns the global account limit for the account on the rse expression.
Expand All @@ -141,7 +143,7 @@ def get_global_account_limit(account, rse_expression, *, session: "Session"):


@read_session
def get_local_account_limit(account, rse_id, *, session: "Session"):
def get_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session") -> Optional[int | float]:
"""
Returns the account limit for the account on the rse.
Expand All @@ -162,7 +164,7 @@ def get_local_account_limit(account, rse_id, *, session: "Session"):


@read_session
def get_local_account_limits(account, rse_ids=None, *, session: "Session"):
def get_local_account_limits(account: InternalAccount, rse_ids: list[uuid.UUID] = None, *, session: "Session") -> dict:
"""
Returns the account limits for the account on the list of rses.
Expand Down Expand Up @@ -196,7 +198,7 @@ def get_local_account_limits(account, rse_ids=None, *, session: "Session"):


@transactional_session
def set_local_account_limit(account, rse_id, bytes_, *, session: "Session"):
def set_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, bytes_: int, *, session: "Session"):
"""
Returns the limits for the account on the rse.
Expand All @@ -214,7 +216,7 @@ def set_local_account_limit(account, rse_id, bytes_, *, session: "Session"):


@transactional_session
def set_global_account_limit(account, rse_expression, bytes_, *, session: "Session"):
def set_global_account_limit(account: InternalAccount, rse_expression: str, bytes_: int, *, session: "Session"):
"""
Sets the global limit for the account on a RSE expression.
Expand All @@ -232,7 +234,7 @@ def set_global_account_limit(account, rse_expression, bytes_, *, session: "Sessi


@transactional_session
def delete_local_account_limit(account, rse_id, *, session: "Session"):
def delete_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, session: "Session") -> bool:
"""
Deletes a local account limit.
Expand All @@ -250,7 +252,7 @@ def delete_local_account_limit(account, rse_id, *, session: "Session"):


@transactional_session
def delete_global_account_limit(account, rse_expression, *, session: "Session"):
def delete_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> bool:
"""
Deletes a global account limit.
Expand All @@ -268,7 +270,7 @@ def delete_global_account_limit(account, rse_expression, *, session: "Session"):


@transactional_session
def get_local_account_usage(account, rse_id=None, *, session: "Session"):
def get_local_account_usage(account: InternalAccount, rse_id: uuid.UUID = None, *, session: "Session") -> list[dict]:
"""
Read the account usage and connect it with (if available) the account limits of the account.
Expand Down Expand Up @@ -311,7 +313,7 @@ def get_local_account_usage(account, rse_id=None, *, session: "Session"):


@transactional_session
def get_global_account_usage(account, rse_expression=None, *, session: "Session"):
def get_global_account_usage(account: InternalAccount, rse_expression: str = None, *, session: "Session") -> list[dict]:
"""
Read the account usage and connect it with the global account limits of the account.
Expand Down

0 comments on commit 0997de5

Please sign in to comment.