Skip to content

Commit

Permalink
Testing: Add type annotations to core/naming_convention; #6588
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio authored and bari12 committed Apr 26, 2024
1 parent 4b84dee commit c76e552
Showing 1 changed file with 40 additions and 9 deletions.
49 changes: 40 additions & 9 deletions lib/rucio/core/naming_convention.py
Expand Up @@ -14,7 +14,7 @@

from re import compile, error, match
from traceback import format_exc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, cast

from dogpile.cache.api import NO_VALUE
from sqlalchemy.exc import IntegrityError
Expand All @@ -26,13 +26,27 @@
from rucio.db.sqla.session import read_session, transactional_session

if TYPE_CHECKING:
from typing import TypedDict

from sqlalchemy.orm import Session

from rucio.common.types import InternalScope

class NamingConventionDict(TypedDict):
scope: InternalScope
regexp: str

REGION = make_region_memcached(expiration_time=900)


@transactional_session
def add_naming_convention(scope, regexp, convention_type, *, session: "Session"):
def add_naming_convention(
scope: "InternalScope",
regexp: str,
convention_type: KeyType,
*,
session: "Session"
) -> None:
"""
add a naming convention for a given scope
Expand All @@ -59,7 +73,12 @@ def add_naming_convention(scope, regexp, convention_type, *, session: "Session")


@read_session
def get_naming_convention(scope, convention_type, *, session: "Session"):
def get_naming_convention(
scope: "InternalScope",
convention_type: KeyType,
*,
session: "Session"
) -> Optional[str]:
"""
Get the naming convention for a given scope
Expand All @@ -77,7 +96,12 @@ def get_naming_convention(scope, convention_type, *, session: "Session"):


@transactional_session
def delete_naming_convention(scope, convention_type, *, session: "Session"):
def delete_naming_convention(
scope: "InternalScope",
convention_type: KeyType,
*,
session: "Session"
) -> int:
"""
delete a naming convention for a given scope
Expand All @@ -86,14 +110,15 @@ def delete_naming_convention(scope, convention_type, *, session: "Session"):
:param convention_type: the did_type on which the regexp should apply.
:param session: The database session in use.
"""
REGION.delete(scope.internal)
if scope.internal is not None:
REGION.delete(scope.internal)
return session.query(models.NamingConvention) \
.filter_by(scope=scope, convention_type=convention_type) \
.delete()


@read_session
def list_naming_conventions(*, session: "Session"):
def list_naming_conventions(*, session: "Session") -> list["NamingConventionDict"]:
"""
List all naming conventions.
Expand All @@ -103,11 +128,17 @@ def list_naming_conventions(*, session: "Session"):
"""
query = session.query(models.NamingConvention.scope,
models.NamingConvention.regexp)
return [row._asdict() for row in query]
return [cast("NamingConventionDict", row._asdict()) for row in query]


@read_session
def validate_name(scope, name, did_type, *, session: "Session"):
def validate_name(
scope: "InternalScope",
name: str,
did_type: str,
*,
session: "Session"
) -> Optional[dict[str, Any]]:
"""
Validate a name according to a naming convention.
Expand Down Expand Up @@ -136,7 +167,7 @@ def validate_name(scope, name, did_type, *, session: "Session"):
return

# Validate with regexp
groups = match(regexp, str(name))
groups = match(regexp, str(name)) # type: ignore
if groups:
meta = groups.groupdict()
# Hack to get task_id from version
Expand Down

0 comments on commit c76e552

Please sign in to comment.