diff --git a/Makefile b/Makefile index 58f6a24..635e2cb 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ docs: html-docs $(OPEN_COMMAND) docs/build/html/index.html test tests: - @$(VENV)/bin/pytest --cov=sure tests/test_runtime/test_scenario_result.py + @$(VENV)/bin/pytest --cov=sure tests/unit/test_astuneval.py @$(VENV)/bin/pytest --cov=sure tests # runs main command-line tool diff --git a/sure/astuneval.py b/sure/astuneval.py new file mode 100644 index 0000000..9b81e31 --- /dev/null +++ b/sure/astuneval.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) <2010-2024> Gabriel Falcão +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +"""astuneval (Abstract Syntax-Tree Unevaluation) - safe substitution for unsafe :func:`eval` +""" +import ast + + +class Accessor(object): + """base class for object element accessors""" + + def __init__(self, astbody): + self.body = astbody + + def __call__(self, object: object, *args, **kw) -> object: + return self.access(object, *args, **kw) + + def access(self, object: object) -> object: + raise NotImplementedError(f"support to {type(self.body)} is not implemented") + + +class NameAccessor(Accessor): + """Accesses an object's attributes through name""" + + def access(self, object: object) -> object: + return getattr(object, self.body.id) + + +class SliceAccessor(Accessor): + """Accesses an object's attributes through slice""" + + def access(self, object: object) -> object: + return object[self.body.value] + + +class SubsAccessor(Accessor): + """Accesses an object's attributes through subscript""" + + def access(self, object: object) -> object: + get_value = NameAccessor(self.body.value) + get_slice = SliceAccessor(self.body.slice) + return get_slice(get_value(object)) + + +class AttributeAccessor(Accessor): + """Accesses an object's attributes through chained attribute""" + + def access(self, object: object) -> object: + attr_name = self.body.attr + access = resolve_accessor(self.body.value) + value = access(object) + return getattr(value, attr_name) + + +def resolve_accessor(body): + return { + ast.Name: NameAccessor, + ast.Subscript: SubsAccessor, + ast.Attribute: AttributeAccessor, + }.get(type(body), Accessor)(body) + + +def parse_accessor(value: str) -> Accessor: + body = parse_body(value) + return resolve_accessor(body) + + +def parse_body(value: str) -> ast.stmt: + bodies = ast.parse(value).body + if len(bodies) > int(True): + raise SyntaxError(f"{repr(value)} exceeds the maximum body count for ast nodes") + + return bodies[0].value diff --git a/sure/original.py b/sure/original.py index b26f8f8..6335b33 100644 --- a/sure/original.py +++ b/sure/original.py @@ -31,6 +31,7 @@ from typing import Union from collections.abc import Iterable +from sure.astuneval import parse_accessor from sure.core import Explanation from sure.core import DeepComparison from sure.core import itemize_length @@ -76,7 +77,7 @@ def __init__(self, src, self.actual = src self._attribute = None - self._eval = None + self.__element_access_expr__ = None self._range = None if all_integers(within_range): if len(within_range) != 2: @@ -351,16 +352,15 @@ def the_attribute(self, attr): return self def in_each(self, attr): - self._eval = attr + self.__element_access_expr__ = attr return self def matches(self, items): msg = '%r[%d].%s should be %r, but is %r' - get_eval = lambda item: eval( - "%s.%s" % ('current', self._eval), {}, {'current': item}, - ) - if self._eval and is_iterable(self.actual): + get_eval = self.__element_access_expr__ and parse_accessor(self.__element_access_expr__) or (lambda x: None) + + if bool(self.__element_access_expr__) and is_iterable(self.actual): if isinstance(items, (str, )): items = [items for x in range(len(items))] else: @@ -381,7 +381,7 @@ def matches(self, items): value = get_eval(item) - error = msg % (self.actual, index, self._eval, other, value) + error = msg % (self.actual, index, self.__element_access_expr__, other, value) if other != value: raise AssertionError(error) else: @@ -409,17 +409,6 @@ def is_empty(self): def are_empty(self): return self.is_empty - def __contains__(self, expectation): - if isinstance(self.actual, dict): - items = self.actual.keys() - - if isinstance(self.actual, Iterable): - items = self.actual - else: - items = dir(self.actual) - - return expectation in items - def contains(self, expectation): if expectation in self.actual: return True diff --git a/tests/unit/test_astuneval.py b/tests/unit/test_astuneval.py new file mode 100644 index 0000000..cbaca15 --- /dev/null +++ b/tests/unit/test_astuneval.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) <2010-2024> Gabriel Falcão +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +import ast +from sure import expects +from sure.astuneval import parse_body +from sure.astuneval import parse_accessor +from sure.astuneval import Accessor, NameAccessor, SubsAccessor, AttributeAccessor + + +def test_parse_body_against_several_kinds(): + expects(parse_body("atomic_bonds[3:7]")).to.be.an(ast.Subscript) + expects(parse_body("children[6]")).to.be.an(ast.Subscript) + expects(parse_body("hippolytus")).to.be.an(ast.Name) + expects(parse_body("zone[4].damage")).to.be.an(ast.Attribute) + + +def test_parse_accessor_name_accessor(): + class Tragedy: + telemachus = "♒️" + + expects(parse_accessor("telemachus")).to.be.a(NameAccessor) + get_character = parse_accessor("telemachus") + expects(get_character(Tragedy)).to.equal('♒️') + + +def test_parse_accessor_subscript_accessor(): + class MonacoGrandPrix1990: + classification = [ + "Ayrton Senna", + "Alain Prost", + "Jean Alesi", + ] + expects(parse_accessor("classification[2]")).to.be.a(SubsAccessor) + get_position = parse_accessor("classification[2]") + expects(get_position(MonacoGrandPrix1990)).to.equal("Jean Alesi") + + +def test_parse_accessor_attr_accessor(): + class FirstResponder: + def __init__(self, bound: str, damage: str): + self.bound = bound + self.damage = damage + + class Incident: + first_responders = [ + FirstResponder("Wyckoff", "unknown"), + FirstResponder("Beth Israel", "unknown"), + FirstResponder("Brooklyn Hospital Center", "unknown"), + FirstResponder("Woodhull", "administered wrong medication"), + ] + expects(parse_accessor("first_responders[3].damage")).to.be.a(AttributeAccessor) + + access_damage = parse_accessor("first_responders[3].damage") + expects(access_damage(Incident)).to.equal("administered wrong medication")