Skip to content

Commit

Permalink
Testing: Add type annotations to db/sqla/util; rucio#6588
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Apr 25, 2024
1 parent 695c4bc commit 16981be
Showing 1 changed file with 39 additions and 23 deletions.
62 changes: 39 additions & 23 deletions lib/rucio/db/sqla/util.py
Expand Up @@ -16,7 +16,7 @@
from datetime import datetime
from hashlib import sha256
from os import urandom
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

import sqlalchemy
from alembic import command, op
Expand All @@ -34,23 +34,26 @@
from rucio.common.cache import make_region_memcached
from rucio.common.config import config_get, config_get_list
from rucio.common.schema import get_schema_value
from rucio.common.types import InternalAccount
from rucio.common.types import InternalAccount, LoggerFunction
from rucio.common.utils import generate_uuid
from rucio.db.sqla import models
from rucio.db.sqla.constants import AccountStatus, AccountType, IdentityType
from rucio.db.sqla.session import get_dump_engine, get_engine, get_session
from rucio.db.sqla.types import InternalScopeString, String

if TYPE_CHECKING:
from typing import Optional, Union # noqa: F401
from collections.abc import Sequence

from sqlalchemy.engine import Inspector # noqa: F401
from sqlalchemy.orm import Session # noqa: F401
from sqlalchemy.engine import Inspector
from sqlalchemy.orm import Query, Session

# TypeVar representing the DeclarativeObj class defined inside _create_temp_table
DeclarativeObj = TypeVar('DeclarativeObj')

REGION = make_region_memcached(expiration_time=600, memcached_expire_time=3660)


def build_database():
def build_database() -> None:
""" Applies the schema to the database. Run this command once to build the database. """
engine = get_engine()

Expand All @@ -71,13 +74,13 @@ def build_database():
command.stamp(alembic_cfg, "head")


def dump_schema():
def dump_schema() -> None:
""" Creates a schema dump to a specific database. """
engine = get_dump_engine()
models.register_models(engine)


def destroy_database():
def destroy_database() -> None:
""" Removes the schema from the database. Only useful for test cases or malicious intents. """
engine = get_engine()

Expand All @@ -87,7 +90,7 @@ def destroy_database():
print('Cannot destroy schema -- assuming already gone, continuing:', e)


def drop_everything():
def drop_everything() -> None:
"""
Pre-gather all named constraints and table names, and drop everything.
This is better than using metadata.reflect(); metadata.drop_all()
Expand Down Expand Up @@ -127,7 +130,7 @@ def drop_everything():
sqlalchemy.Enum(**enum).drop(bind=conn)


def create_base_vo():
def create_base_vo() -> None:
""" Creates the base VO """

session_scoped = get_session()
Expand All @@ -138,7 +141,7 @@ def create_base_vo():
s.add_all([vo])


def create_root_account():
def create_root_account() -> None:
"""
Inserts the default root account to an existing database. Make sure to change the default password later.
"""
Expand Down Expand Up @@ -216,7 +219,7 @@ def create_root_account():
s.commit()


def get_db_time():
def get_db_time() -> Optional[datetime]:
""" Gives the utc time on the db. """
session_scoped = get_session()
try:
Expand All @@ -241,7 +244,7 @@ def get_db_time():
session_scoped.remove()


def get_count(q):
def get_count(q: "Query") -> int:
"""
Fast way to get count in SQLAlchemy
Source: https://gist.github.com/hest/8798884
Expand All @@ -253,7 +256,7 @@ def get_count(q):
return count


def is_old_db():
def is_old_db() -> bool:
"""
Returns true, if alembic is used and the database is not on the
same revision as the code base.
Expand All @@ -274,7 +277,7 @@ def is_old_db():
return (len(query) != 0 and str(query[0].version_num) != alembicrevision.ALEMBIC_REVISION)


def json_implemented(*, session=None):
def json_implemented(*, session: Optional["Session"] = None) -> bool:
"""
Checks if the database on the current server installation can support json fields.
Expand All @@ -295,7 +298,7 @@ def json_implemented(*, session=None):
return True


def try_drop_constraint(constraint_name, table_name):
def try_drop_constraint(constraint_name: str, table_name: str) -> None:
"""
Tries to drop the given constrained and returns successfully if the
constraint already existed on Oracle databases.
Expand All @@ -309,7 +312,7 @@ def try_drop_constraint(constraint_name, table_name):
assert 'nonexistent constraint' in str(e)


def list_oracle_global_temp_tables(session):
def list_oracle_global_temp_tables(session: "Session") -> list[str]:
"""
Retrieve the list of global temporary tables in oracle
"""
Expand Down Expand Up @@ -338,7 +341,14 @@ def list_oracle_global_temp_tables(session):
return global_temp_tables


def _create_temp_table(name, *columns, primary_key=None, oracle_global_name=None, session=None, logger=logging.log):
def _create_temp_table(
name: str,
*columns: "Sequence[Column]",
primary_key: Optional["Sequence[Any]"] = None,
oracle_global_name: Optional[str] = None,
session: Optional["Session"] = None,
logger: LoggerFunction = logging.log
) -> type["DeclarativeObj"]:
"""
Create a temporary table with the given columns, register it into a declarative base, and return it.
Expand Down Expand Up @@ -452,18 +462,24 @@ class TempTableManager:
sessions in multiple threads at a time, so no need to protect indexes with a mutex.
"""

def __init__(self, session):
def __init__(self, session: "Session"):
self.session = session

self.next_idx_to_use = {}

def create_temp_table(self, name, *columns, primary_key=None, logger=logging.log):
def create_temp_table(
self,
name: str,
*columns: "Sequence[Column]",
primary_key: Optional["Sequence[Any]"] = None,
logger: LoggerFunction = logging.log
) -> type["DeclarativeObj"]:
idx = self.next_idx_to_use.setdefault(name, 0)
table = _create_temp_table(f'{name}_{idx}', *columns, primary_key=primary_key, session=self.session, logger=logger)
self.next_idx_to_use[name] = idx + 1
return table

def create_scope_name_table(self, logger=logging.log):
def create_scope_name_table(self, logger: LoggerFunction = logging.log) -> type["DeclarativeObj"]:
"""
Create a temporary table with columns 'scope' and 'name'
"""
Expand All @@ -479,7 +495,7 @@ def create_scope_name_table(self, logger=logging.log):
logger=logger,
)

def create_association_table(self, logger=logging.log):
def create_association_table(self, logger: LoggerFunction = logging.log) -> type["DeclarativeObj"]:
"""
Create a temporary table with columns 'scope', 'name', 'child_scope'and 'child_name'
"""
Expand All @@ -497,7 +513,7 @@ def create_association_table(self, logger=logging.log):
logger=logger,
)

def create_id_table(self, logger=logging.log):
def create_id_table(self, logger: LoggerFunction = logging.log) -> type["DeclarativeObj"]:
"""
Create a temp table with a single id column of uuid type
"""
Expand Down

0 comments on commit 16981be

Please sign in to comment.