Skip to content

Commit

Permalink
Testing: Refactor FilterEngine and add type annotations; rucio#6588
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Mar 26, 2024
1 parent 4d2808d commit 2566841
Showing 1 changed file with 66 additions and 37 deletions.
103 changes: 66 additions & 37 deletions lib/rucio/core/did_meta_plugins/filter_engine.py
Expand Up @@ -17,10 +17,11 @@
import operator
from datetime import date, datetime, timedelta
from importlib import import_module
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

import sqlalchemy
from sqlalchemy import and_, cast, or_
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.expression import text

from rucio.common import exception
Expand All @@ -29,8 +30,14 @@
from rucio.db.sqla.session import read_session

if TYPE_CHECKING:
from collections.abc import Callable, Iterable

from sqlalchemy.orm import Session

from rucio.db.sqla.models import ModelBase

KeyType = TypeVar("KeyType", str, InstrumentedAttribute)
FilterTuple = tuple[KeyType, Callable[[object, object], Any], Union[bool, datetime, float, str]]

# lookup table converting keyword suffixes to pythonic operators.
OPERATORS_CONVERSION_LUT = {
Expand Down Expand Up @@ -76,25 +83,30 @@ class FilterEngine:
"""
An engine to provide advanced filtering functionality to DID listing requests.
"""
def __init__(self, filters, model_class=None, strict_coerce=True):
def __init__(
self,
filters: Union[str, dict[str, Any], list[dict[str, Any]]],
model_class: Optional[type["ModelBase"]] = None,
strict_coerce: bool = True
):
if isinstance(filters, str):
self._filters, _ = parse_did_filter_from_string_fe(filters, omit_name=True)
filters, _ = parse_did_filter_from_string_fe(filters, omit_name=True)
elif isinstance(filters, dict):
self._filters = [filters]
filters = [filters]
elif isinstance(filters, list):
self._filters = filters
filters = filters
else:
raise exception.DIDFilterSyntaxError("Input filters are of an unrecognised type.")

self._make_input_backwards_compatible()
self.mandatory_model_attributes = self._translate_filters(model_class=model_class, strict_coerce=strict_coerce)
filters = self._make_input_backwards_compatible(filters=filters)
self._filters, self.mandatory_model_attributes = self._translate_filters(filters=filters, model_class=model_class, strict_coerce=strict_coerce)
self._sanity_check_translated_filters()

@property
def filters(self):
def filters(self) -> list[list["FilterTuple"]]:
return self._filters

def _coerce_filter_word_to_model_attribute(self, word, model_class, strict=True):
def _coerce_filter_word_to_model_attribute(self, word: Any, model_class: Optional[type["ModelBase"]], strict: bool = True) -> Any:
"""
Attempts to coerce a filter word to an attribute of a <model_class>.
Expand All @@ -112,21 +124,22 @@ def _coerce_filter_word_to_model_attribute(self, word, model_class, strict=True)
raise exception.KeyNotFound("'{}' keyword could not be coerced to model class attribute. Attribute not found.".format(word))
return word

def _make_input_backwards_compatible(self):
def _make_input_backwards_compatible(self, filters: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Backwards compatibility for previous versions of filtering.
Does the following:
- converts "created_after" key to "created_at.gte"
- converts "created_before" key to "created_at.lte"
"""
for or_group in self._filters:
for or_group in filters:
if 'created_after' in or_group:
or_group['created_at.gte'] = or_group.pop('created_after')
elif 'created_before' in or_group:
or_group['created_at.lte'] = or_group.pop('created_before')
return filters

def _sanity_check_translated_filters(self):
def _sanity_check_translated_filters(self) -> None:
"""
Perform a few sanity checks on translated filters.
Expand All @@ -152,7 +165,7 @@ def _sanity_check_translated_filters(self):
raise ValueError("Name operator must be an equality operator.")
if key == 'length': # (3)
try:
int(value)
int(value) # type: ignore
except ValueError:
raise ValueError('Length has to be an integer value.')

Expand All @@ -169,7 +182,12 @@ def _sanity_check_translated_filters(self):
if len(set(or_group_test_duplicates)) != len(or_group_test_duplicates): # (6)
raise exception.DuplicateCriteriaInDIDFilter()

def _translate_filters(self, model_class, strict_coerce=True):
def _translate_filters(
self,
filters: "Iterable[dict[str, Any]]",
model_class: Optional[type["ModelBase"]],
strict_coerce: bool = True
) -> tuple[list[list["FilterTuple"]], list[InstrumentedAttribute[Any]]]:
"""
Reformats filters from:
Expand Down Expand Up @@ -198,9 +216,10 @@ def _translate_filters(self, model_class, strict_coerce=True):
Typecasting of values is also attempted.
:param filters: The filters to translate.
:param model_class: The SQL model class.
:param strict_coerce: Enforce that keywords must be coercable to a model attribute.
:returns: The set of mandatory model attributes to be used in the filter query.
:returns: The list of translated filters, and the set of mandatory model attributes to be used in the filter query.
:raises: MissingModuleException, DIDFilterSyntaxError
"""
if model_class:
Expand All @@ -211,7 +230,7 @@ def _translate_filters(self, model_class, strict_coerce=True):

mandatory_model_attributes = set()
filters_translated = []
for or_group in self._filters:
for or_group in filters:
and_group_parsed = []
for key, value in or_group.items():
# KEY
Expand Down Expand Up @@ -244,10 +263,9 @@ def _translate_filters(self, model_class, strict_coerce=True):
and_group_parsed.append(
(key_no_suffix, OPERATORS_CONVERSION_LUT.get(oper), value))
filters_translated.append(and_group_parsed)
self._filters = filters_translated
return list(mandatory_model_attributes)
return filters_translated, list(mandatory_model_attributes)

def _try_typecast_string(self, value):
def _try_typecast_string(self, value: str) -> Union[bool, datetime, float, str]:
"""
Check if string can be typecasted to bool, datetime or float.
Expand All @@ -258,11 +276,11 @@ def _try_typecast_string(self, value):
value = value.replace('false', 'False').replace('FALSE', 'False')
for format in VALID_DATE_FORMATS: # try parsing multiple date formats.
try:
value = datetime.strptime(value, format)
typecasted_value = datetime.strptime(value, format)
except ValueError:
continue
else:
return value
return typecasted_value
try:
operators = ('+', '-', '*', '/')
if not any(operator in value for operator in operators): # fix for lax ast literal_eval in earlier python versions
Expand All @@ -271,7 +289,7 @@ def _try_typecast_string(self, value):
pass
return value

def create_mongo_query(self, additional_filters={}):
def create_mongo_query(self, additional_filters: "Iterable[FilterTuple]" = []) -> dict[str, Any]:
"""
Returns a single mongo query describing the filters expression.
Expand All @@ -283,9 +301,9 @@ def create_mongo_query(self, additional_filters={}):
for filter in additional_filters:
or_group.append(list(filter)) # type: ignore

or_expressions = []
or_expressions: list[dict[str, Any]] = []
for or_group in self._filters:
and_expressions = []
and_expressions: list[dict[str, dict[str, Any]]] = []
for and_group in or_group:
key, oper, value = and_group
if isinstance(value, str) and any([char in value for char in ['*', '%']]): # wildcards
Expand Down Expand Up @@ -326,8 +344,12 @@ def create_mongo_query(self, additional_filters={}):

return query_str

def create_postgres_query(self, additional_filters={}, fixed_table_columns=('scope', 'name', 'vo'),
jsonb_column='data'):
def create_postgres_query(
self,
additional_filters: "Iterable[FilterTuple]" = [],
fixed_table_columns: tuple[str, ...] | dict[str, str] = ('scope', 'name', 'vo'),
jsonb_column: str = 'data'
) -> str:
"""
Returns a single postgres query describing the filters expression.
Expand All @@ -340,9 +362,9 @@ def create_postgres_query(self, additional_filters={}, fixed_table_columns=('sco
for _filter in additional_filters:
or_group.append(list(_filter)) # type: ignore

or_expressions = []
or_expressions: list[str] = []
for or_group in self._filters:
and_expressions = []
and_expressions: list[str] = []
for and_group in or_group:
key, oper, value = and_group
if key in fixed_table_columns: # is this key filtering on a column or in the jsonb?
Expand Down Expand Up @@ -400,7 +422,14 @@ def create_postgres_query(self, additional_filters={}, fixed_table_columns=('sco
return ' OR '.join(or_expressions)

@read_session
def create_sqla_query(self, *, session: "Session", additional_model_attributes=[], additional_filters={}, json_column=None):
def create_sqla_query(
self,
*,
session: "Session",
additional_model_attributes: list[InstrumentedAttribute[Any]] = [],
additional_filters: "Iterable[FilterTuple]" = [],
json_column: Optional[Any] = None
) -> Any:
"""
Returns a database query that fully describes the filters.
Expand All @@ -421,12 +450,12 @@ def create_sqla_query(self, *, session: "Session", additional_model_attributes=[
for _filter in additional_filters:
or_group.append(list(_filter)) # type: ignore

or_expressions = []
or_expressions: list = []
for or_group in self._filters:
and_expressions = []
for and_group in or_group:
key, oper, value = and_group
if isinstance(key, sqlalchemy.orm.attributes.InstrumentedAttribute): # -> this key filters on a table column.
if isinstance(key, InstrumentedAttribute): # -> this key filters on a table column.
if isinstance(value, str) and any([char in value for char in ['*', '%']]): # wildcards
if value in ('*', '%', '*', '%'): # match wildcard exactly == no filtering on key
continue
Expand Down Expand Up @@ -491,7 +520,7 @@ def create_sqla_query(self, *, session: "Session", additional_model_attributes=[
or_expressions.append(and_(*and_expressions))
return session.query(*all_model_attributes).filter(or_(*or_expressions))

def evaluate(self):
def evaluate(self) -> bool:
"""
Evaluates an expression and returns a boolean result.
Expand All @@ -506,7 +535,7 @@ def evaluate(self):
or_group_evaluations.append(all(and_group_evaluations))
return any(or_group_evaluations)

def print_filters(self):
def print_filters(self) -> str:
"""
A (more) human readable format of <filters>.
"""
Expand All @@ -516,16 +545,16 @@ def print_filters(self):
for or_group in self._filters:
for and_group in or_group:
key, oper, value = and_group
if isinstance(key, sqlalchemy.orm.attributes.InstrumentedAttribute):
if isinstance(key, InstrumentedAttribute):
key = and_group[0].key
if operators_conversion_LUT_inv[oper] == "":
oper = "eq"
else:
oper = operators_conversion_LUT_inv[oper]
if isinstance(value, sqlalchemy.orm.attributes.InstrumentedAttribute):
value = and_group[2].key
if isinstance(value, InstrumentedAttribute):
value = and_group[2].key # type: ignore
elif isinstance(value, DIDType):
value = and_group[2].name
value = and_group[2].name # type: ignore
filters = "{}{} {} {}".format(filters, key, oper, value)
if and_group != or_group[-1]:
filters += ' AND '
Expand Down

0 comments on commit 2566841

Please sign in to comment.