From 13b79a9a3c0e004baf02ac49cf7331073970ca59 Mon Sep 17 00:00:00 2001 From: rdimaio Date: Sat, 23 Mar 2024 16:46:26 +0100 Subject: [PATCH] Testing: Add type annotations to models.py; #6588 --- lib/rucio/db/sqla/models.py | 47 +++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/lib/rucio/db/sqla/models.py b/lib/rucio/db/sqla/models.py index c811714843..720b2762be 100644 --- a/lib/rucio/db/sqla/models.py +++ b/lib/rucio/db/sqla/models.py @@ -14,7 +14,7 @@ import uuid from datetime import datetime, timedelta -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from sqlalchemy import BigInteger, Boolean, DateTime, Enum, Float, Integer, SmallInteger, String, Text, UniqueConstraint, event from sqlalchemy.engine import Engine @@ -54,6 +54,11 @@ from rucio.db.sqla.session import BASE from rucio.db.sqla.types import GUID, JSON, BooleanString, InternalAccountString, InternalScopeString +if TYPE_CHECKING: + from sqlalchemy.engine.base import Connection + from sqlalchemy.orm import Session + from sqlalchemy.sql import Insert, Update + # SQLAlchemy defines the corresponding code behind TYPE_CHECKING # https://github.com/sqlalchemy/sqlalchemy/blob/d9acd6223299c118464d30abfa483e26a536239d/lib/sqlalchemy/orm/base.py#L814 # And pylint/astroid don't have an option to evaluate this code @@ -64,24 +69,24 @@ @compiles(Boolean, "oracle") -def compile_binary_oracle(type_, compiler, **kw): +def compile_binary_oracle(type_, compiler, **kw) -> str: return "NUMBER(1)" @event.listens_for(Table, "before_create") -def _mysql_rename_type(target, connection, **kw): +def _mysql_rename_type(target: Table, connection: "Connection", **kw) -> None: if connection.dialect.name == 'mysql' and target.name == 'quarantined_replicas': target.columns.path.type = String(255) @event.listens_for(Table, "before_create") -def _psql_rename_type(target, connection, **kw): +def _psql_rename_type(target: Table, connection: "Connection", **kw) -> None: if connection.dialect.name == 'postgresql' and target.name == 'account_map': target.columns.identity_type.type.name = 'IDENTITIES_TYPE_CHK' @event.listens_for(Table, "before_create") -def _oracle_json_constraint(target, connection, **kw): +def _oracle_json_constraint(target: Table, connection: "Connection", **kw) -> None: if connection.dialect.name == 'oracle': try: oracle_version = int(connection.connection.version.split('.')[0]) @@ -95,7 +100,13 @@ def _oracle_json_constraint(target, connection, **kw): @event.listens_for(Engine, "before_execute", retval=True) -def _add_hint(conn, element, multiparams, params, execution_options): +def _add_hint( + conn: "Connection", + element: Union[Delete, "Insert", "Update"], + multiparams, + params, + execution_options +) -> tuple[Union[Delete, "Insert", "Update"], list, dict]: if conn.dialect.name == 'oracle' and isinstance(element, Delete) and element.table.name == 'locks': element = element.prefix_with("/*+ INDEX(LOCKS LOCKS_PK) */") if conn.dialect.name == 'oracle' and isinstance(element, Delete) and element.table.name == 'replicas': @@ -110,7 +121,7 @@ def _add_hint(conn, element, multiparams, params, execution_options): @event.listens_for(PrimaryKeyConstraint, "after_parent_attach") -def _pk_constraint_name(const, table): +def _pk_constraint_name(const: PrimaryKeyConstraint, table: Table) -> None: if table.name.upper() == 'QUARANTINED_REPLICAS_HISTORY': const.name = "QRD_REPLICAS_HISTORY_PK" else: @@ -118,7 +129,7 @@ def _pk_constraint_name(const, table): @event.listens_for(ForeignKeyConstraint, "after_parent_attach") -def _fk_constraint_name(const, table): +def _fk_constraint_name(const: ForeignKeyConstraint, table: Table) -> None: if const.name: return fk = const.elements[0] @@ -129,14 +140,14 @@ def _fk_constraint_name(const, table): @event.listens_for(UniqueConstraint, "after_parent_attach") -def _unique_constraint_name(const, table): +def _unique_constraint_name(const: UniqueConstraint, table: Table) -> None: if const.name: return const.name = "uq_%s_%s" % (table.name, list(const.columns)[0].name) @event.listens_for(CheckConstraint, "after_parent_attach") -def _ck_constraint_name(const, table): +def _ck_constraint_name(const: CheckConstraint, table: Table) -> None: if const.name is None: if 'DELETED' in str(const.sqltext).upper(): @@ -219,23 +230,23 @@ def created_at(cls): # pylint: disable=no-self-argument def updated_at(cls): # pylint: disable=no-self-argument return mapped_column("updated_at", DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - def save(self, flush=True, session=None): + def save(self, flush: bool = True, session: Optional["Session"] = None) -> None: """Save this object""" - # Sessions created with autoflush=True be default since sqlAlchemy 1.4. - # So explicatly calling session.flush is not necessary. + # Sessions created with autoflush: bool = True be default since sqlAlchemy 1.4. + # So explicitly calling session.flush is not necessary. # However, when autogenerated primary keys involved, calling # session.flush to get the id from DB. session.add(self) if flush: session.flush() - def delete(self, flush=True, session=None): + def delete(self, flush: bool = True, session: Optional["Session"] = None) -> None: """Delete this object""" session.delete(self) if flush: session.flush() - def update(self, values, flush=True, session=None): + def update(self, values: dict[str, Any], flush: bool = True, session: Optional["Session"] = None) -> None: """dict.update() behaviour.""" for k, v in values.items(): self[k] = v @@ -293,7 +304,7 @@ def deleted(cls): # pylint: disable=no-self-argument def deleted_at(cls): # pylint: disable=no-self-argument return mapped_column("deleted_at", DateTime) - def delete(self, flush=True, session=None): + def delete(self, flush: bool = True, session: Optional["Session"] = None) -> None: """Delete this object""" self.deleted = True self.deleted_at = datetime.utcnow() @@ -1690,7 +1701,7 @@ class FollowEvents(BASE, ModelBase): Index('DIDS_FOLLOWED_EVENTS_ACC_IDX', 'account')) -def register_models(engine): +def register_models(engine: Engine) -> None: """ Creates database tables for all models with the given engine """ @@ -1761,7 +1772,7 @@ def register_models(engine): model.metadata.create_all(engine) # pylint: disable=maybe-no-member -def unregister_models(engine): +def unregister_models(engine: Engine) -> None: """ Drops database tables for all models with the given engine """