Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing: Add type annotations to db/sqla/util #6734

Merged
merged 1 commit into from May 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 40 additions & 24 deletions lib/rucio/db/sqla/util.py
Expand Up @@ -16,7 +16,7 @@
from datetime import datetime
from hashlib import sha256
from os import urandom
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

import sqlalchemy
from alembic import command, op
Expand All @@ -34,23 +34,26 @@
from rucio.common.cache import make_region_memcached
from rucio.common.config import config_get, config_get_list
from rucio.common.schema import get_schema_value
from rucio.common.types import InternalAccount
from rucio.common.types import InternalAccount, LoggerFunction
from rucio.common.utils import generate_uuid
from rucio.db.sqla import models
from rucio.db.sqla.constants import AccountStatus, AccountType, IdentityType
from rucio.db.sqla.session import get_dump_engine, get_engine, get_session
from rucio.db.sqla.types import InternalScopeString, String

if TYPE_CHECKING:
from typing import Optional, Union # noqa: F401
from collections.abc import Sequence

from sqlalchemy.engine import Inspector # noqa: F401
from sqlalchemy.orm import Session # noqa: F401
from sqlalchemy.engine import Inspector
from sqlalchemy.orm import Query, Session

# TypeVar representing the DeclarativeObj class defined inside _create_temp_table
DeclarativeObj = TypeVar('DeclarativeObj')

REGION = make_region_memcached(expiration_time=600, memcached_expire_time=3660)


def build_database():
def build_database() -> None:
""" Applies the schema to the database. Run this command once to build the database. """
engine = get_engine()

Expand All @@ -71,13 +74,13 @@ def build_database():
command.stamp(alembic_cfg, "head")


def dump_schema():
def dump_schema() -> None:
""" Creates a schema dump to a specific database. """
engine = get_dump_engine()
models.register_models(engine)


def destroy_database():
def destroy_database() -> None:
""" Removes the schema from the database. Only useful for test cases or malicious intents. """
engine = get_engine()

Expand All @@ -87,7 +90,7 @@ def destroy_database():
print('Cannot destroy schema -- assuming already gone, continuing:', e)


def drop_everything():
def drop_everything() -> None:
"""
Pre-gather all named constraints and table names, and drop everything.
This is better than using metadata.reflect(); metadata.drop_all()
Expand All @@ -101,7 +104,7 @@ def drop_everything():

with engine.connect() as conn:

inspector = inspect(conn) # type: Union[Inspector, PGInspector]
inspector: Union["Inspector", PGInspector] = inspect(conn)

for tname, fkcs in reversed(
inspector.get_sorted_table_and_fkc_names(schema='*')):
Expand All @@ -127,7 +130,7 @@ def drop_everything():
sqlalchemy.Enum(**enum).drop(bind=conn)


def create_base_vo():
def create_base_vo() -> None:
""" Creates the base VO """

session_scoped = get_session()
Expand All @@ -138,7 +141,7 @@ def create_base_vo():
s.add_all([vo])


def create_root_account():
def create_root_account() -> None:
"""
Inserts the default root account to an existing database. Make sure to change the default password later.
"""
Expand Down Expand Up @@ -216,7 +219,7 @@ def create_root_account():
s.commit()


def get_db_time():
def get_db_time() -> Optional[datetime]:
""" Gives the utc time on the db. """
session_scoped = get_session()
try:
Expand All @@ -241,7 +244,7 @@ def get_db_time():
session_scoped.remove()


def get_count(q):
def get_count(q: "Query") -> int:
"""
Fast way to get count in SQLAlchemy
Source: https://gist.github.com/hest/8798884
Expand All @@ -253,7 +256,7 @@ def get_count(q):
return count


def is_old_db():
def is_old_db() -> bool:
"""
Returns true, if alembic is used and the database is not on the
same revision as the code base.
Expand All @@ -274,7 +277,7 @@ def is_old_db():
return (len(query) != 0 and str(query[0].version_num) != alembicrevision.ALEMBIC_REVISION)


def json_implemented(*, session=None):
def json_implemented(*, session: Optional["Session"] = None) -> bool:
"""
Checks if the database on the current server installation can support json fields.

Expand All @@ -295,7 +298,7 @@ def json_implemented(*, session=None):
return True


def try_drop_constraint(constraint_name, table_name):
def try_drop_constraint(constraint_name: str, table_name: str) -> None:
"""
Tries to drop the given constrained and returns successfully if the
constraint already existed on Oracle databases.
Expand All @@ -309,7 +312,7 @@ def try_drop_constraint(constraint_name, table_name):
assert 'nonexistent constraint' in str(e)


def list_oracle_global_temp_tables(session):
def list_oracle_global_temp_tables(session: "Session") -> list[str]:
"""
Retrieve the list of global temporary tables in oracle
"""
Expand Down Expand Up @@ -338,7 +341,14 @@ def list_oracle_global_temp_tables(session):
return global_temp_tables


def _create_temp_table(name, *columns, primary_key=None, oracle_global_name=None, session=None, logger=logging.log):
def _create_temp_table(
name: str,
*columns: "Sequence[Column]",
primary_key: Optional["Sequence[Any]"] = None,
oracle_global_name: Optional[str] = None,
session: Optional["Session"] = None,
logger: LoggerFunction = logging.log
) -> type["DeclarativeObj"]:
"""
Create a temporary table with the given columns, register it into a declarative base, and return it.

Expand Down Expand Up @@ -452,18 +462,24 @@ class TempTableManager:
sessions in multiple threads at a time, so no need to protect indexes with a mutex.
"""

def __init__(self, session):
def __init__(self, session: "Session"):
self.session = session

self.next_idx_to_use = {}

def create_temp_table(self, name, *columns, primary_key=None, logger=logging.log):
def create_temp_table(
self,
name: str,
*columns: "Sequence[Column]",
primary_key: Optional["Sequence[Any]"] = None,
logger: LoggerFunction = logging.log
) -> type["DeclarativeObj"]:
idx = self.next_idx_to_use.setdefault(name, 0)
table = _create_temp_table(f'{name}_{idx}', *columns, primary_key=primary_key, session=self.session, logger=logger)
self.next_idx_to_use[name] = idx + 1
return table

def create_scope_name_table(self, logger=logging.log):
def create_scope_name_table(self, logger: LoggerFunction = logging.log) -> type["DeclarativeObj"]:
"""
Create a temporary table with columns 'scope' and 'name'
"""
Expand All @@ -479,7 +495,7 @@ def create_scope_name_table(self, logger=logging.log):
logger=logger,
)

def create_association_table(self, logger=logging.log):
def create_association_table(self, logger: LoggerFunction = logging.log) -> type["DeclarativeObj"]:
"""
Create a temporary table with columns 'scope', 'name', 'child_scope'and 'child_name'
"""
Expand All @@ -497,7 +513,7 @@ def create_association_table(self, logger=logging.log):
logger=logger,
)

def create_id_table(self, logger=logging.log):
def create_id_table(self, logger: LoggerFunction = logging.log) -> type["DeclarativeObj"]:
"""
Create a temp table with a single id column of uuid type
"""
Expand Down