Skip to content

Commit

Permalink
Testing: Add type annotations to models.py; rucio#6588
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Mar 23, 2024
1 parent bed3f05 commit 13b79a9
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions lib/rucio/db/sqla/models.py
Expand Up @@ -14,7 +14,7 @@

import uuid
from datetime import datetime, timedelta
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

from sqlalchemy import BigInteger, Boolean, DateTime, Enum, Float, Integer, SmallInteger, String, Text, UniqueConstraint, event
from sqlalchemy.engine import Engine
Expand Down Expand Up @@ -54,6 +54,11 @@
from rucio.db.sqla.session import BASE
from rucio.db.sqla.types import GUID, JSON, BooleanString, InternalAccountString, InternalScopeString

if TYPE_CHECKING:
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import Session
from sqlalchemy.sql import Insert, Update

# SQLAlchemy defines the corresponding code behind TYPE_CHECKING
# https://github.com/sqlalchemy/sqlalchemy/blob/d9acd6223299c118464d30abfa483e26a536239d/lib/sqlalchemy/orm/base.py#L814
# And pylint/astroid don't have an option to evaluate this code
Expand All @@ -64,24 +69,24 @@


@compiles(Boolean, "oracle")
def compile_binary_oracle(type_, compiler, **kw):
def compile_binary_oracle(type_, compiler, **kw) -> str:
return "NUMBER(1)"


@event.listens_for(Table, "before_create")
def _mysql_rename_type(target, connection, **kw):
def _mysql_rename_type(target: Table, connection: "Connection", **kw) -> None:
if connection.dialect.name == 'mysql' and target.name == 'quarantined_replicas':
target.columns.path.type = String(255)


@event.listens_for(Table, "before_create")
def _psql_rename_type(target, connection, **kw):
def _psql_rename_type(target: Table, connection: "Connection", **kw) -> None:
if connection.dialect.name == 'postgresql' and target.name == 'account_map':
target.columns.identity_type.type.name = 'IDENTITIES_TYPE_CHK'


@event.listens_for(Table, "before_create")
def _oracle_json_constraint(target, connection, **kw):
def _oracle_json_constraint(target: Table, connection: "Connection", **kw) -> None:
if connection.dialect.name == 'oracle':
try:
oracle_version = int(connection.connection.version.split('.')[0])
Expand All @@ -95,7 +100,13 @@ def _oracle_json_constraint(target, connection, **kw):


@event.listens_for(Engine, "before_execute", retval=True)
def _add_hint(conn, element, multiparams, params, execution_options):
def _add_hint(
conn: "Connection",
element: Union[Delete, "Insert", "Update"],
multiparams,
params,
execution_options
) -> tuple[Union[Delete, "Insert", "Update"], list, dict]:
if conn.dialect.name == 'oracle' and isinstance(element, Delete) and element.table.name == 'locks':
element = element.prefix_with("/*+ INDEX(LOCKS LOCKS_PK) */")
if conn.dialect.name == 'oracle' and isinstance(element, Delete) and element.table.name == 'replicas':
Expand All @@ -110,15 +121,15 @@ def _add_hint(conn, element, multiparams, params, execution_options):


@event.listens_for(PrimaryKeyConstraint, "after_parent_attach")
def _pk_constraint_name(const, table):
def _pk_constraint_name(const: PrimaryKeyConstraint, table: Table) -> None:
if table.name.upper() == 'QUARANTINED_REPLICAS_HISTORY':
const.name = "QRD_REPLICAS_HISTORY_PK"
else:
const.name = "%s_PK" % (table.name.upper(),)


@event.listens_for(ForeignKeyConstraint, "after_parent_attach")
def _fk_constraint_name(const, table):
def _fk_constraint_name(const: ForeignKeyConstraint, table: Table) -> None:
if const.name:
return
fk = const.elements[0]
Expand All @@ -129,14 +140,14 @@ def _fk_constraint_name(const, table):


@event.listens_for(UniqueConstraint, "after_parent_attach")
def _unique_constraint_name(const, table):
def _unique_constraint_name(const: UniqueConstraint, table: Table) -> None:
if const.name:
return
const.name = "uq_%s_%s" % (table.name, list(const.columns)[0].name)


@event.listens_for(CheckConstraint, "after_parent_attach")
def _ck_constraint_name(const, table):
def _ck_constraint_name(const: CheckConstraint, table: Table) -> None:

if const.name is None:
if 'DELETED' in str(const.sqltext).upper():
Expand Down Expand Up @@ -219,23 +230,23 @@ def created_at(cls): # pylint: disable=no-self-argument
def updated_at(cls): # pylint: disable=no-self-argument
return mapped_column("updated_at", DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

def save(self, flush=True, session=None):
def save(self, flush: bool = True, session: Optional["Session"] = None) -> None:
"""Save this object"""
# Sessions created with autoflush=True be default since sqlAlchemy 1.4.
# So explicatly calling session.flush is not necessary.
# Sessions created with autoflush: bool = True be default since sqlAlchemy 1.4.
# So explicitly calling session.flush is not necessary.
# However, when autogenerated primary keys involved, calling
# session.flush to get the id from DB.
session.add(self)
if flush:
session.flush()

def delete(self, flush=True, session=None):
def delete(self, flush: bool = True, session: Optional["Session"] = None) -> None:
"""Delete this object"""
session.delete(self)
if flush:
session.flush()

def update(self, values, flush=True, session=None):
def update(self, values: dict[str, Any], flush: bool = True, session: Optional["Session"] = None) -> None:
"""dict.update() behaviour."""
for k, v in values.items():
self[k] = v
Expand Down Expand Up @@ -293,7 +304,7 @@ def deleted(cls): # pylint: disable=no-self-argument
def deleted_at(cls): # pylint: disable=no-self-argument
return mapped_column("deleted_at", DateTime)

def delete(self, flush=True, session=None):
def delete(self, flush: bool = True, session: Optional["Session"] = None) -> None:
"""Delete this object"""
self.deleted = True
self.deleted_at = datetime.utcnow()
Expand Down Expand Up @@ -1690,7 +1701,7 @@ class FollowEvents(BASE, ModelBase):
Index('DIDS_FOLLOWED_EVENTS_ACC_IDX', 'account'))


def register_models(engine):
def register_models(engine: Engine) -> None:
"""
Creates database tables for all models with the given engine
"""
Expand Down Expand Up @@ -1761,7 +1772,7 @@ def register_models(engine):
model.metadata.create_all(engine) # pylint: disable=maybe-no-member


def unregister_models(engine):
def unregister_models(engine: Engine) -> None:
"""
Drops database tables for all models with the given engine
"""
Expand Down

0 comments on commit 13b79a9

Please sign in to comment.