From 16981be12bc7a4a05f4e2832acf4d04b350a9b9c Mon Sep 17 00:00:00 2001 From: rdimaio Date: Thu, 25 Apr 2024 17:14:42 +0200 Subject: [PATCH] Testing: Add type annotations to db/sqla/util; #6588 --- lib/rucio/db/sqla/util.py | 62 ++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/lib/rucio/db/sqla/util.py b/lib/rucio/db/sqla/util.py index b4e2e85f78..9ff44c0174 100644 --- a/lib/rucio/db/sqla/util.py +++ b/lib/rucio/db/sqla/util.py @@ -16,7 +16,7 @@ from datetime import datetime from hashlib import sha256 from os import urandom -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union import sqlalchemy from alembic import command, op @@ -34,7 +34,7 @@ from rucio.common.cache import make_region_memcached from rucio.common.config import config_get, config_get_list from rucio.common.schema import get_schema_value -from rucio.common.types import InternalAccount +from rucio.common.types import InternalAccount, LoggerFunction from rucio.common.utils import generate_uuid from rucio.db.sqla import models from rucio.db.sqla.constants import AccountStatus, AccountType, IdentityType @@ -42,15 +42,18 @@ from rucio.db.sqla.types import InternalScopeString, String if TYPE_CHECKING: - from typing import Optional, Union # noqa: F401 + from collections.abc import Sequence - from sqlalchemy.engine import Inspector # noqa: F401 - from sqlalchemy.orm import Session # noqa: F401 + from sqlalchemy.engine import Inspector + from sqlalchemy.orm import Query, Session + + # TypeVar representing the DeclarativeObj class defined inside _create_temp_table + DeclarativeObj = TypeVar('DeclarativeObj') REGION = make_region_memcached(expiration_time=600, memcached_expire_time=3660) -def build_database(): +def build_database() -> None: """ Applies the schema to the database. Run this command once to build the database. """ engine = get_engine() @@ -71,13 +74,13 @@ def build_database(): command.stamp(alembic_cfg, "head") -def dump_schema(): +def dump_schema() -> None: """ Creates a schema dump to a specific database. """ engine = get_dump_engine() models.register_models(engine) -def destroy_database(): +def destroy_database() -> None: """ Removes the schema from the database. Only useful for test cases or malicious intents. """ engine = get_engine() @@ -87,7 +90,7 @@ def destroy_database(): print('Cannot destroy schema -- assuming already gone, continuing:', e) -def drop_everything(): +def drop_everything() -> None: """ Pre-gather all named constraints and table names, and drop everything. This is better than using metadata.reflect(); metadata.drop_all() @@ -127,7 +130,7 @@ def drop_everything(): sqlalchemy.Enum(**enum).drop(bind=conn) -def create_base_vo(): +def create_base_vo() -> None: """ Creates the base VO """ session_scoped = get_session() @@ -138,7 +141,7 @@ def create_base_vo(): s.add_all([vo]) -def create_root_account(): +def create_root_account() -> None: """ Inserts the default root account to an existing database. Make sure to change the default password later. """ @@ -216,7 +219,7 @@ def create_root_account(): s.commit() -def get_db_time(): +def get_db_time() -> Optional[datetime]: """ Gives the utc time on the db. """ session_scoped = get_session() try: @@ -241,7 +244,7 @@ def get_db_time(): session_scoped.remove() -def get_count(q): +def get_count(q: "Query") -> int: """ Fast way to get count in SQLAlchemy Source: https://gist.github.com/hest/8798884 @@ -253,7 +256,7 @@ def get_count(q): return count -def is_old_db(): +def is_old_db() -> bool: """ Returns true, if alembic is used and the database is not on the same revision as the code base. @@ -274,7 +277,7 @@ def is_old_db(): return (len(query) != 0 and str(query[0].version_num) != alembicrevision.ALEMBIC_REVISION) -def json_implemented(*, session=None): +def json_implemented(*, session: Optional["Session"] = None) -> bool: """ Checks if the database on the current server installation can support json fields. @@ -295,7 +298,7 @@ def json_implemented(*, session=None): return True -def try_drop_constraint(constraint_name, table_name): +def try_drop_constraint(constraint_name: str, table_name: str) -> None: """ Tries to drop the given constrained and returns successfully if the constraint already existed on Oracle databases. @@ -309,7 +312,7 @@ def try_drop_constraint(constraint_name, table_name): assert 'nonexistent constraint' in str(e) -def list_oracle_global_temp_tables(session): +def list_oracle_global_temp_tables(session: "Session") -> list[str]: """ Retrieve the list of global temporary tables in oracle """ @@ -338,7 +341,14 @@ def list_oracle_global_temp_tables(session): return global_temp_tables -def _create_temp_table(name, *columns, primary_key=None, oracle_global_name=None, session=None, logger=logging.log): +def _create_temp_table( + name: str, + *columns: "Sequence[Column]", + primary_key: Optional["Sequence[Any]"] = None, + oracle_global_name: Optional[str] = None, + session: Optional["Session"] = None, + logger: LoggerFunction = logging.log +) -> type["DeclarativeObj"]: """ Create a temporary table with the given columns, register it into a declarative base, and return it. @@ -452,18 +462,24 @@ class TempTableManager: sessions in multiple threads at a time, so no need to protect indexes with a mutex. """ - def __init__(self, session): + def __init__(self, session: "Session"): self.session = session self.next_idx_to_use = {} - def create_temp_table(self, name, *columns, primary_key=None, logger=logging.log): + def create_temp_table( + self, + name: str, + *columns: "Sequence[Column]", + primary_key: Optional["Sequence[Any]"] = None, + logger: LoggerFunction = logging.log + ) -> type["DeclarativeObj"]: idx = self.next_idx_to_use.setdefault(name, 0) table = _create_temp_table(f'{name}_{idx}', *columns, primary_key=primary_key, session=self.session, logger=logger) self.next_idx_to_use[name] = idx + 1 return table - def create_scope_name_table(self, logger=logging.log): + def create_scope_name_table(self, logger: LoggerFunction = logging.log) -> type["DeclarativeObj"]: """ Create a temporary table with columns 'scope' and 'name' """ @@ -479,7 +495,7 @@ def create_scope_name_table(self, logger=logging.log): logger=logger, ) - def create_association_table(self, logger=logging.log): + def create_association_table(self, logger: LoggerFunction = logging.log) -> type["DeclarativeObj"]: """ Create a temporary table with columns 'scope', 'name', 'child_scope'and 'child_name' """ @@ -497,7 +513,7 @@ def create_association_table(self, logger=logging.log): logger=logger, ) - def create_id_table(self, logger=logging.log): + def create_id_table(self, logger: LoggerFunction = logging.log) -> type["DeclarativeObj"]: """ Create a temp table with a single id column of uuid type """