diff --git a/lib/rucio/core/did_meta_plugins/filter_engine.py b/lib/rucio/core/did_meta_plugins/filter_engine.py
index b6a7e55db8..57f7c572fc 100644
--- a/lib/rucio/core/did_meta_plugins/filter_engine.py
+++ b/lib/rucio/core/did_meta_plugins/filter_engine.py
@@ -17,10 +17,11 @@
 import operator
 from datetime import date, datetime, timedelta
 from importlib import import_module
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 
 import sqlalchemy
 from sqlalchemy import and_, cast, or_
+from sqlalchemy.orm.attributes import InstrumentedAttribute
 from sqlalchemy.sql.expression import text
 
 from rucio.common import exception
@@ -29,8 +30,14 @@
 from rucio.db.sqla.session import read_session
 
 if TYPE_CHECKING:
+    from collections.abc import Callable, Iterable
+
     from sqlalchemy.orm import Session
+    from rucio.db.sqla.models import ModelBase
+
+    KeyType = TypeVar("KeyType", str, InstrumentedAttribute)
+    FilterTuple = tuple[KeyType, Callable[[object, object], Any], Union[bool, datetime, float, str]]
 
 
 # lookup table converting keyword suffixes to pythonic operators.
 OPERATORS_CONVERSION_LUT = {
@@ -76,25 +83,30 @@ class FilterEngine:
     """
     An engine to provide advanced filtering functionality to DID listing requests.
     """
-    def __init__(self, filters, model_class=None, strict_coerce=True):
+    def __init__(
+            self,
+            filters: Union[str, dict[str, Any], list[dict[str, Any]]],
+            model_class: Optional[type["ModelBase"]] = None,
+            strict_coerce: bool = True
+    ):
         if isinstance(filters, str):
-            self._filters, _ = parse_did_filter_from_string_fe(filters, omit_name=True)
+            filters, _ = parse_did_filter_from_string_fe(filters, omit_name=True)
         elif isinstance(filters, dict):
-            self._filters = [filters]
+            filters = [filters]
         elif isinstance(filters, list):
-            self._filters = filters
+            filters = filters
         else:
             raise exception.DIDFilterSyntaxError("Input filters are of an unrecognised type.")
-        self._make_input_backwards_compatible()
-        self.mandatory_model_attributes = self._translate_filters(model_class=model_class, strict_coerce=strict_coerce)
+        filters = self._make_input_backwards_compatible(filters=filters)
+        self._filters, self.mandatory_model_attributes = self._translate_filters(filters=filters, model_class=model_class, strict_coerce=strict_coerce)
         self._sanity_check_translated_filters()
 
     @property
-    def filters(self):
+    def filters(self) -> list[list["FilterTuple"]]:
         return self._filters
 
-    def _coerce_filter_word_to_model_attribute(self, word, model_class, strict=True):
+    def _coerce_filter_word_to_model_attribute(self, word: Any, model_class: Optional[type["ModelBase"]], strict: bool = True) -> Any:
         """
         Attempts to coerce a filter word to an attribute of a <model_class>.
 
@@ -112,7 +124,7 @@ def _coerce_filter_word_to_model_attribute(self, word, model_class, strict=True)
                 raise exception.KeyNotFound("'{}' keyword could not be coerced to model class attribute. Attribute not found.".format(word))
         return word
 
-    def _make_input_backwards_compatible(self):
+    def _make_input_backwards_compatible(self, filters: list[dict[str, Any]]) -> list[dict[str, Any]]:
         """
         Backwards compatibility for previous versions of filtering.
 
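
For context, a minimal usage sketch of the constructor re-typed above (not part of the patch; the filter values and the choice of models.DataIdentifier are illustrative only):

    from datetime import datetime

    from rucio.core.did_meta_plugins.filter_engine import FilterEngine
    from rucio.db.sqla import models

    # A single dict is one AND-group of criteria; keyword suffixes such as ".gte"/".lte"
    # are mapped to Python operators via OPERATORS_CONVERSION_LUT.
    engine = FilterEngine({'type': 'CONTAINER', 'length.gte': 10}, model_class=models.DataIdentifier)

    # A list of dicts is OR-ed together, and legacy keys are still accepted because
    # _make_input_backwards_compatible() rewrites them (created_after -> created_at.gte).
    engine = FilterEngine(
        [{'created_after': datetime(2024, 1, 1)}, {'type': 'DATASET'}],
        model_class=models.DataIdentifier,
    )
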
@@ -120,13 +132,14 @@ def _make_input_backwards_compatible(self):
         - converts "created_after" key to "created_at.gte"
         - converts "created_before" key to "created_at.lte"
         """
-        for or_group in self._filters:
+        for or_group in filters:
             if 'created_after' in or_group:
                 or_group['created_at.gte'] = or_group.pop('created_after')
             elif 'created_before' in or_group:
                 or_group['created_at.lte'] = or_group.pop('created_before')
+        return filters
 
-    def _sanity_check_translated_filters(self):
+    def _sanity_check_translated_filters(self) -> None:
         """
         Perform a few sanity checks on translated filters.
 
@@ -152,7 +165,7 @@ def _sanity_check_translated_filters(self):
                         raise ValueError("Name operator must be an equality operator.")
                 if key == 'length':  # (3)
                     try:
-                        int(value)
+                        int(value)  # type: ignore
                     except ValueError:
                         raise ValueError('Length has to be an integer value.')
 
@@ -169,7 +182,12 @@
             if len(set(or_group_test_duplicates)) != len(or_group_test_duplicates):  # (6)
                 raise exception.DuplicateCriteriaInDIDFilter()
 
-    def _translate_filters(self, model_class, strict_coerce=True):
+    def _translate_filters(
+            self,
+            filters: "Iterable[dict[str, Any]]",
+            model_class: Optional[type["ModelBase"]],
+            strict_coerce: bool = True
+    ) -> tuple[list[list["FilterTuple"]], list[InstrumentedAttribute[Any]]]:
         """
         Reformats filters from:
 
@@ -198,9 +216,10 @@ def _translate_filters(self, model_class, strict_coerce=True):
 
         Typecasting of values is also attempted.
 
+        :param filters: The filters to translate.
         :param model_class: The SQL model class.
         :param strict_coerce: Enforce that keywords must be coercable to a model attribute.
-        :returns: The set of mandatory model attributes to be used in the filter query.
+        :returns: The list of translated filters, and the set of mandatory model attributes to be used in the filter query.
         :raises: MissingModuleException, DIDFilterSyntaxError
         """
         if model_class:
@@ -211,7 +230,7 @@ def _translate_filters(self, model_class, strict_coerce=True):
 
         mandatory_model_attributes = set()
         filters_translated = []
-        for or_group in self._filters:
+        for or_group in filters:
             and_group_parsed = []
             for key, value in or_group.items():
                 # KEY
@@ -244,10 +263,9 @@ def _translate_filters(self, model_class, strict_coerce=True):
                     and_group_parsed.append(
                         (key_no_suffix, OPERATORS_CONVERSION_LUT.get(oper), value))
             filters_translated.append(and_group_parsed)
-        self._filters = filters_translated
-        return list(mandatory_model_attributes)
+        return filters_translated, list(mandatory_model_attributes)
 
-    def _try_typecast_string(self, value):
+    def _try_typecast_string(self, value: str) -> Union[bool, datetime, float, str]:
         """
         Check if string can be typecasted to bool, datetime or float.
 
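
To illustrate the new return contract of _translate_filters() (a sketch only; the exact coercion of keys and values, e.g. to model columns or DIDType, is handled inside the method and is unchanged by this patch):

    import operator

    # Input, after _make_input_backwards_compatible() has rewritten any legacy keys:
    filters = [
        {'length.gte': 10, 'run_number': 49000},  # AND-group 1
        {'name': 'dataset_*'},                    # AND-group 2, OR-ed with group 1
    ]

    # Approximate shape of the first element of the returned tuple: a list of OR-groups,
    # each a list of (key, operator, value) FilterTuples. With a model_class and
    # strict_coerce=True the string keys become InstrumentedAttribute columns instead.
    filters_translated = [
        [('length', operator.ge, 10), ('run_number', operator.eq, 49000)],
        [('name', operator.eq, 'dataset_*')],
    ]

    # The second element lists the model attributes the eventual query must select on.
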
@@ -258,11 +276,11 @@ def _try_typecast_string(self, value):
         value = value.replace('false', 'False').replace('FALSE', 'False')
         for format in VALID_DATE_FORMATS:  # try parsing multiple date formats.
             try:
-                value = datetime.strptime(value, format)
+                typecasted_value = datetime.strptime(value, format)
             except ValueError:
                 continue
             else:
-                return value
+                return typecasted_value
         try:
             operators = ('+', '-', '*', '/')
             if not any(operator in value for operator in operators):  # fix for lax ast literal_eval in earlier python versions
@@ -271,7 +289,7 @@
             pass
         return value
 
-    def create_mongo_query(self, additional_filters={}):
+    def create_mongo_query(self, additional_filters: "Iterable[FilterTuple]" = []) -> dict[str, Any]:
         """
         Returns a single mongo query describing the filters expression.
 
@@ -283,9 +301,9 @@
             for filter in additional_filters:
                 or_group.append(list(filter))  # type: ignore
 
-        or_expressions = []
+        or_expressions: list[dict[str, Any]] = []
         for or_group in self._filters:
-            and_expressions = []
+            and_expressions: list[dict[str, dict[str, Any]]] = []
             for and_group in or_group:
                 key, oper, value = and_group
                 if isinstance(value, str) and any([char in value for char in ['*', '%']]):  # wildcards
@@ -326,8 +344,12 @@
 
         return query_str
 
-    def create_postgres_query(self, additional_filters={}, fixed_table_columns=('scope', 'name', 'vo'),
-                              jsonb_column='data'):
+    def create_postgres_query(
+            self,
+            additional_filters: "Iterable[FilterTuple]" = [],
+            fixed_table_columns: Union[tuple[str, ...], dict[str, str]] = ('scope', 'name', 'vo'),
+            jsonb_column: str = 'data'
+    ) -> str:
         """
         Returns a single postgres query describing the filters expression.
 
@@ -340,9 +362,9 @@ def create_postgres_query(self, additional_filters={}, fixed_table_columns=('sco
             for _filter in additional_filters:
                 or_group.append(list(_filter))  # type: ignore
 
-        or_expressions = []
+        or_expressions: list[str] = []
         for or_group in self._filters:
-            and_expressions = []
+            and_expressions: list[str] = []
             for and_group in or_group:
                 key, oper, value = and_group
                 if key in fixed_table_columns:  # is this key filtering on a column or in the jsonb?
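
A rough sketch of what the two newly annotated entry points return (illustrative only; the exact operator spellings and jsonb casts come from the unchanged method bodies):

    from rucio.core.did_meta_plugins.filter_engine import FilterEngine

    engine = FilterEngine({'length.gte': 10, 'run_number': 49000}, model_class=None, strict_coerce=False)

    # create_mongo_query() now declares dict[str, Any]; the result is a mongo-style
    # expression, roughly of the form {'$or': [{'$and': [{'length': {'$gte': 10}}, ...]}]}.
    mongo_query = engine.create_mongo_query()

    # create_postgres_query() now declares str; the result is a WHERE-clause fragment that
    # filters fixed table columns directly and everything else inside the jsonb 'data' column.
    postgres_query = engine.create_postgres_query()
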
@@ -400,7 +422,14 @@ def create_postgres_query(self, additional_filters={}, fixed_table_columns=('sco
         return ' OR '.join(or_expressions)
 
     @read_session
-    def create_sqla_query(self, *, session: "Session", additional_model_attributes=[], additional_filters={}, json_column=None):
+    def create_sqla_query(
+            self,
+            *,
+            session: "Session",
+            additional_model_attributes: list[InstrumentedAttribute[Any]] = [],
+            additional_filters: "Iterable[FilterTuple]" = [],
+            json_column: Optional[Any] = None
+    ) -> Any:
         """
         Returns a database query that fully describes the filters.
 
@@ -421,12 +450,12 @@ def create_sqla_query(self, *, session: "Session", additional_model_attributes=[
             for _filter in additional_filters:
                 or_group.append(list(_filter))  # type: ignore
 
-        or_expressions = []
+        or_expressions: list = []
         for or_group in self._filters:
             and_expressions = []
             for and_group in or_group:
                 key, oper, value = and_group
-                if isinstance(key, sqlalchemy.orm.attributes.InstrumentedAttribute):  # -> this key filters on a table column.
+                if isinstance(key, InstrumentedAttribute):  # -> this key filters on a table column.
                     if isinstance(value, str) and any([char in value for char in ['*', '%']]):  # wildcards
                         if value in ('*', '%', u'*', u'%'):  # match wildcard exactly == no filtering on key
                             continue
@@ -491,7 +520,7 @@ def create_sqla_query(self, *, session: "Session", additional_model_attributes=[
             or_expressions.append(and_(*and_expressions))
         return session.query(*all_model_attributes).filter(or_(*or_expressions))
 
-    def evaluate(self):
+    def evaluate(self) -> bool:
         """
         Evaluates an expression and returns a boolean result.
 
@@ -506,7 +535,7 @@ def evaluate(self):
             or_group_evaluations.append(all(and_group_evaluations))
         return any(or_group_evaluations)
 
-    def print_filters(self):
+    def print_filters(self) -> str:
         """
         A (more) human readable format of <filters>.
         """
@@ -516,16 +545,16 @@ def print_filters(self):
         for or_group in self._filters:
             for and_group in or_group:
                 key, oper, value = and_group
-                if isinstance(key, sqlalchemy.orm.attributes.InstrumentedAttribute):
+                if isinstance(key, InstrumentedAttribute):
                     key = and_group[0].key
                 if operators_conversion_LUT_inv[oper] == "":
                     oper = "eq"
                 else:
                     oper = operators_conversion_LUT_inv[oper]
-                if isinstance(value, sqlalchemy.orm.attributes.InstrumentedAttribute):
-                    value = and_group[2].key
+                if isinstance(value, InstrumentedAttribute):
+                    value = and_group[2].key  # type: ignore
                 elif isinstance(value, DIDType):
-                    value = and_group[2].name
+                    value = and_group[2].name  # type: ignore
                 filters = "{}{} {} {}".format(filters, key, oper, value)
                 if and_group != or_group[-1]:
                     filters += ' AND '
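
Finally, a sketch of the SQLAlchemy path typed above (illustrative; the read_session decorator injects the session, and the chosen model attributes are examples only):

    from rucio.core.did_meta_plugins.filter_engine import FilterEngine
    from rucio.db.sqla import models

    engine = FilterEngine({'length.gte': 10}, model_class=models.DataIdentifier)

    # The query selects the mandatory model attributes derived from the filters plus any
    # extra columns requested here, combining OR-groups with or_() and AND-groups with and_().
    query = engine.create_sqla_query(
        additional_model_attributes=[models.DataIdentifier.scope, models.DataIdentifier.name],
    )

    # print_filters() is now annotated to return str, e.g. something like "length gte 10".
    print(engine.print_filters())
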