From a131dca844e578bcb9248fc4880a7b6c5872667d Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 1 Sep 2021 00:12:05 +0300 Subject: [PATCH 01/10] ENH: [f2py] Symbolic solver for dimension specifications. --- numpy/f2py/crackfortran.py | 391 +++----- numpy/f2py/symbolic.py | 1258 +++++++++++++++++++++++++ numpy/f2py/tests/test_crackfortran.py | 99 ++ numpy/f2py/tests/test_symbolic.py | 395 ++++++++ 4 files changed, 1882 insertions(+), 261 deletions(-) create mode 100644 numpy/f2py/symbolic.py create mode 100644 numpy/f2py/tests/test_symbolic.py diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py index 3ac9b80c8816..cd30d3cec489 100755 --- a/numpy/f2py/crackfortran.py +++ b/numpy/f2py/crackfortran.py @@ -153,7 +153,8 @@ # As the needed functions cannot be determined by static inspection of the # code, it is safest to use import * pending a major refactoring of f2py. from .auxfuncs import * - +from .symbolic import Expr, Language +from . import symbolic f2py_version = __version__.version @@ -875,10 +876,11 @@ def appenddecl(decl, decl2, force=1): decl[k] = decl2[k] elif k == 'note': pass - elif k in ['intent', 'check', 'dimension', 'optional', 'required']: + elif k in ['intent', 'check', 'dimension', 'optional', + 'required', 'depend']: errmess('appenddecl: "%s" not implemented.\n' % k) else: - raise Exception('appenddecl: Unknown variable definition key:' + + raise Exception('appenddecl: Unknown variable definition key: ' + str(k)) return decl @@ -2217,188 +2219,6 @@ def getlincoef(e, xset): # e = a*x+b ; x in xset break return None, None, None -_varname_match = re.compile(r'\A[a-z]\w*\Z').match - - -def getarrlen(dl, args, star='*'): - """ - Parameters - ---------- - dl : sequence of two str objects - dimensions of the array - args : Iterable[str] - symbols used in the expression - star : Any - unused - - Returns - ------- - expr : str - Some numeric expression as a string - arg : Optional[str] - If understood, the argument from `args` present in `expr` - expr2 : Optional[str] - If understood, an expression fragment that should be used as - ``"(%s%s".format(something, expr2)``. - - Examples - -------- - >>> getarrlen(['10*x + 20', '40*x'], {'x'}) - ('30 * x - 19', 'x', '+19)/(30)') - >>> getarrlen(['1', '10*x + 20'], {'x'}) - ('10 * x + 20', 'x', '-20)/(10)') - >>> getarrlen(['10*x + 20', '1'], {'x'}) - ('-10 * x - 18', 'x', '+18)/(-10)') - >>> getarrlen(['20', '1'], {'x'}) - ('-18', None, None) - """ - edl = [] - try: - edl.append(myeval(dl[0], {}, {})) - except Exception: - edl.append(dl[0]) - try: - edl.append(myeval(dl[1], {}, {})) - except Exception: - edl.append(dl[1]) - if isinstance(edl[0], int): - p1 = 1 - edl[0] - if p1 == 0: - d = str(dl[1]) - elif p1 < 0: - d = '%s-%s' % (dl[1], -p1) - else: - d = '%s+%s' % (dl[1], p1) - elif isinstance(edl[1], int): - p1 = 1 + edl[1] - if p1 == 0: - d = '-(%s)' % (dl[0]) - else: - d = '%s-(%s)' % (p1, dl[0]) - else: - d = '%s-(%s)+1' % (dl[1], dl[0]) - try: - return repr(myeval(d, {}, {})), None, None - except Exception: - pass - d1, d2 = getlincoef(dl[0], args), getlincoef(dl[1], args) - if None not in [d1[0], d2[0]]: - if (d1[0], d2[0]) == (0, 0): - return repr(d2[1] - d1[1] + 1), None, None - b = d2[1] - d1[1] + 1 - d1 = (d1[0], 0, d1[2]) - d2 = (d2[0], b, d2[2]) - if d1[0] == 0 and d2[2] in args: - if b < 0: - return '%s * %s - %s' % (d2[0], d2[2], -b), d2[2], '+%s)/(%s)' % (-b, d2[0]) - elif b: - return '%s * %s + %s' % (d2[0], d2[2], b), d2[2], '-%s)/(%s)' % (b, d2[0]) - else: - return '%s * %s' % (d2[0], d2[2]), d2[2], ')/(%s)' % (d2[0]) - if d2[0] == 0 and d1[2] in args: - - if b < 0: - return '%s * %s - %s' % (-d1[0], d1[2], -b), d1[2], '+%s)/(%s)' % (-b, -d1[0]) - elif b: - return '%s * %s + %s' % (-d1[0], d1[2], b), d1[2], '-%s)/(%s)' % (b, -d1[0]) - else: - return '%s * %s' % (-d1[0], d1[2]), d1[2], ')/(%s)' % (-d1[0]) - if d1[2] == d2[2] and d1[2] in args: - a = d2[0] - d1[0] - if not a: - return repr(b), None, None - if b < 0: - return '%s * %s - %s' % (a, d1[2], -b), d2[2], '+%s)/(%s)' % (-b, a) - elif b: - return '%s * %s + %s' % (a, d1[2], b), d2[2], '-%s)/(%s)' % (b, a) - else: - return '%s * %s' % (a, d1[2]), d2[2], ')/(%s)' % (a) - if d1[0] == d2[0] == 1: - c = str(d1[2]) - if c not in args: - if _varname_match(c): - outmess('\tgetarrlen:variable "%s" undefined\n' % (c)) - c = '(%s)' % c - if b == 0: - d = '%s-%s' % (d2[2], c) - elif b < 0: - d = '%s-%s-%s' % (d2[2], c, -b) - else: - d = '%s-%s+%s' % (d2[2], c, b) - elif d1[0] == 0: - c2 = str(d2[2]) - if c2 not in args: - if _varname_match(c2): - outmess('\tgetarrlen:variable "%s" undefined\n' % (c2)) - c2 = '(%s)' % c2 - if d2[0] == 1: - pass - elif d2[0] == -1: - c2 = '-%s' % c2 - else: - c2 = '%s*%s' % (d2[0], c2) - - if b == 0: - d = c2 - elif b < 0: - d = '%s-%s' % (c2, -b) - else: - d = '%s+%s' % (c2, b) - elif d2[0] == 0: - c1 = str(d1[2]) - if c1 not in args: - if _varname_match(c1): - outmess('\tgetarrlen:variable "%s" undefined\n' % (c1)) - c1 = '(%s)' % c1 - if d1[0] == 1: - c1 = '-%s' % c1 - elif d1[0] == -1: - c1 = '+%s' % c1 - elif d1[0] < 0: - c1 = '+%s*%s' % (-d1[0], c1) - else: - c1 = '-%s*%s' % (d1[0], c1) - - if b == 0: - d = c1 - elif b < 0: - d = '%s-%s' % (c1, -b) - else: - d = '%s+%s' % (c1, b) - else: - c1 = str(d1[2]) - if c1 not in args: - if _varname_match(c1): - outmess('\tgetarrlen:variable "%s" undefined\n' % (c1)) - c1 = '(%s)' % c1 - if d1[0] == 1: - c1 = '-%s' % c1 - elif d1[0] == -1: - c1 = '+%s' % c1 - elif d1[0] < 0: - c1 = '+%s*%s' % (-d1[0], c1) - else: - c1 = '-%s*%s' % (d1[0], c1) - - c2 = str(d2[2]) - if c2 not in args: - if _varname_match(c2): - outmess('\tgetarrlen:variable "%s" undefined\n' % (c2)) - c2 = '(%s)' % c2 - if d2[0] == 1: - pass - elif d2[0] == -1: - c2 = '-%s' % c2 - else: - c2 = '%s*%s' % (d2[0], c2) - - if b == 0: - d = '%s%s' % (c2, c1) - elif b < 0: - d = '%s%s-%s' % (c2, c1, -b) - else: - d = '%s%s+%s' % (c2, c1, b) - return d, None, None word_pattern = re.compile(r'\b[a-z][\w$]*\b', re.I) @@ -2683,7 +2503,7 @@ def analyzevars(block): pass vars[n]['kindselector']['kind'] = l - savelindims = {} + dimension_exprs = {} if 'attrspec' in vars[n]: attr = vars[n]['attrspec'] attr.reverse() @@ -2736,18 +2556,18 @@ def analyzevars(block): if dim and 'dimension' not in vars[n]: vars[n]['dimension'] = [] for d in rmbadname([x.strip() for x in markoutercomma(dim).split('@,@')]): - star = '*' - if d == ':': - star = ':' + star = ':' if d == ':' else '*' + # Evaluate `d` with respect to params if d in params: d = str(params[d]) - for p in list(params.keys()): + for p in params: re_1 = re.compile(r'(?P.*?)\b' + p + r'\b(?P.*)', re.I) m = re_1.match(d) while m: d = m.group('before') + \ str(params[p]) + m.group('after') m = re_1.match(d) + if d == star: dl = [star] else: @@ -2755,17 +2575,45 @@ def analyzevars(block): if len(dl) == 2 and '*' in dl: # e.g. dimension(5:*) dl = ['*'] d = '*' - if len(dl) == 1 and not dl[0] == star: + if len(dl) == 1 and dl[0] != star: dl = ['1', dl[0]] if len(dl) == 2: - d, v, di = getarrlen(dl, list(block['vars'].keys())) - if d[:4] == '1 * ': - d = d[4:] - if di and di[-4:] == '/(1)': - di = di[:-4] - if v: - savelindims[d] = v, di + d1, d2 = map(Expr.parse, dl) + dsize = d2 - d1 + 1 + d = dsize.tostring(language=Language.C) + # find variables v that define d as a linear + # function, `d == a * v + b`, and store + # coefficients a and b for further analysis. + solver_and_deps = {} + for v in block['vars']: + s = symbolic.as_symbol(v) + if dsize.contains(s): + try: + a, b = dsize.linear_solve(s) + solve_v = lambda s: (s - b) / a + all_symbols = set(a.symbols()) + all_symbols.update(b.symbols()) + except RuntimeError as msg: + # d is not a linear function of v, + # however, if v can be determined + # from d using other means, + # implement the corresponding + # solve_v function here. + solve_v = None + all_symbols = set(dsize.symbols()) + v_deps = set( + s.data for s in all_symbols + if s.data in vars) + solver_and_deps[v] = solve_v, list(v_deps) + # Note that dsize may contain symbols that are + # not defined in block['vars']. Here we assume + # these correspond to Fortran/C intrinsic + # functions or that are defined by other + # means. We'll let the compiler to validate + # the definiteness of such symbols. + dimension_exprs[d] = solver_and_deps vars[n]['dimension'].append(d) + if 'dimension' in vars[n]: if isintent_c(vars[n]): shape_macro = 'shape' @@ -2789,67 +2637,88 @@ def analyzevars(block): else: errmess( "analyzevars: charselector=%r unhandled." % (d)) + if 'check' not in vars[n] and 'args' in block and n in block['args']: - flag = 'depend' not in vars[n] - if flag: - vars[n]['depend'] = [] - vars[n]['check'] = [] - if 'dimension' in vars[n]: - #/----< no check - i = -1 - ni = len(vars[n]['dimension']) - for d in vars[n]['dimension']: - ddeps = [] # dependencies of 'd' - ad = '' - pd = '' - if d not in vars: - if d in savelindims: - pd, ad = '(', savelindims[d][1] - d = savelindims[d][0] - else: - for r in block['args']: - if r not in vars: - continue - if re.match(r'.*?\b' + r + r'\b', d, re.I): - ddeps.append(r) - if d in vars: - if 'attrspec' in vars[d]: - for aa in vars[d]['attrspec']: - if aa[:6] == 'depend': - ddeps += aa[6:].strip()[1:-1].split(',') - if 'depend' in vars[d]: - ddeps = ddeps + vars[d]['depend'] - i = i + 1 - if d in vars and ('depend' not in vars[d]) \ - and ('=' not in vars[d]) and (d not in vars[n]['depend']) \ - and l_or(isintent_in, isintent_inout, isintent_inplace)(vars[n]): - vars[d]['depend'] = [n] - if ni > 1: - vars[d]['='] = '%s%s(%s,%s)%s' % ( - pd, shape_macro, n, i, ad) - else: - vars[d]['='] = '%slen(%s)%s' % (pd, n, ad) - # /---< no check - if 1 and 'check' not in vars[d]: - if ni > 1: - vars[d]['check'] = ['%s%s(%s,%i)%s==%s' - % (pd, shape_macro, n, i, ad, d)] - else: - vars[d]['check'] = [ - '%slen(%s)%s>=%s' % (pd, n, ad, d)] - if 'attrspec' not in vars[d]: - vars[d]['attrspec'] = ['optional'] - if ('optional' not in vars[d]['attrspec']) and\ - ('required' not in vars[d]['attrspec']): - vars[d]['attrspec'].append('optional') - elif d not in ['*', ':']: - #/----< no check - if flag: - if d in vars: - if n not in ddeps: - vars[n]['depend'].append(d) + # n is an argument that has no checks defined. Here we + # generate some consistency checks for n, and when n is an + # array, generate checks for its dimensions and construct + # initialization expressions. + n_deps = vars[n].get('depend', []) + n_checks = [] + n_is_input = l_or(isintent_in, isintent_inout, isintent_inplace)(vars[n]) + if 'dimension' in vars[n]: # n is array + ni = len(vars[n]['dimension']) # array dimensionality + for i, d in enumerate(vars[n]['dimension']): + coeffs_and_deps = dimension_exprs.get(d) + if coeffs_and_deps is None: + # d is `:` or `*` or a constant expression + pass + elif n_is_input: + # n is an input array argument and its shape + # may define variables used in dimension + # specifications. + for v, (solver, deps) in coeffs_and_deps.items(): + if (v in n_deps + or '=' in vars[v] + or 'depend' in vars[v]): + # Skip a variable that + # - n depends on + # - has user-defined initialization expression + # - has user-defined dependecies + continue + if solver is not None: + # v can be solved from d, hence, we + # make it an optional argument with + # initialization expression: + is_required = False + init = solver(symbolic.as_symbol( + f'shape({n}, {i})')) + init = init.tostring(language=Language.C) + vars[v]['='] = init + # n needs to be initialzed before v. So, + # making v dependent on n and on any + # variables in solver or d. + vars[v]['depend'] = [n] + deps + if 'check' not in vars[v]: + # add check only when no + # user-specified checks exist + vars[v]['check'] = [ + f'shape({n}, {i}) == {d}'] else: - vars[n]['depend'] = vars[n]['depend'] + ddeps + # d is a non-linear function on v, + # hence, v must be a required input + # argument that n will depend on + is_required = True + if 'intent' not in vars[v]: + vars[v]['intent'] = [] + if 'in' not in vars[v]['intent']: + vars[v]['intent'].append('in') + # v needs to be initialized before n + n_deps.append(v) + n_checks.append( + f'shape({n}, {i}) == {d}') + v_attr = vars[v].get('attrspec', []) + if not ('optional' in v_attr + or 'required' in v_attr): + v_attr.append( + 'required' if is_required else 'optional') + if v_attr: + vars[v]['attrspec'] = v_attr + else: + # n is output or hidden argument, hence it + # will depend on all variables in d + n_deps.extend(coeffs_and_deps) + + if coeffs_and_deps is not None: + # extend v dependencies with ones specified in attrspec + for v, (solver, deps) in coeffs_and_deps.items(): + v_deps = vars[v].get('depend', []) + for aa in vars[v].get('attrspec', []): + if aa.startswith('depend'): + aa = ''.join(aa.split()) + v_deps.extend(aa[7:-1].split(',')) + if v_deps: + vars[v]['depend'] = list(set(v_deps)) elif isstring(vars[n]): length = '1' if 'charselector' in vars[n]: @@ -2862,11 +2731,11 @@ def analyzevars(block): params) del vars[n]['charselector']['len'] vars[n]['charselector']['*'] = length + if n_checks: + vars[n]['check'] = n_checks + if n_deps: + vars[n]['depend'] = list(set(n_deps)) - if not vars[n]['check']: - del vars[n]['check'] - if flag and not vars[n]['depend']: - del vars[n]['depend'] if '=' in vars[n]: if 'attrspec' not in vars[n]: vars[n]['attrspec'] = [] diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py new file mode 100644 index 000000000000..d303da0a0d30 --- /dev/null +++ b/numpy/f2py/symbolic.py @@ -0,0 +1,1258 @@ +"""Fortran/C symbolic expressions + +References: +- J3/21-007: Draft Fortran 202x. https://j3-fortran.org/doc/year/21/21-007.pdf +""" + +# To analyze Fortran expressions to solve dimensions specifications, +# for instances, we implement a minimal symbolic engine for parsing +# expressions into a tree of expression instances. As a first +# instance, we care only about arithmetic expressions involving +# integers and operations like addition (+), subtraction (-), +# multiplication (*), division (Fortran / is Python //, Fortran // is +# concatenate), and exponentiation (**). In addition, .pyf files may +# contain C expressions that support here is implemented as well. +# +# TODO: support logical constants (Op.BOOLEAN) +# TODO: support logical operators (.AND., ...) +# TODO: support relational operators (<, >, ..., .LT., ...) +# TODO: support defined operators (.MYOP., ...) +# +__all__ = ['Expr'] + + +import re +import warnings +from enum import Enum +from math import gcd + + +class Language(Enum): + """ + Used as Expr.tostring language argument. + """ + Python = 0 + Fortran = 1 + C = 2 + + +class Op(Enum): + """ + Used as Expr op attribute. + """ + INTEGER = 10 + REAL = 12 + COMPLEX = 15 + STRING = 20 + ARRAY = 30 + SYMBOL = 40 + APPLY = 200 + INDEXING = 210 + CONCAT = 220 + TERMS = 1000 + FACTORS = 2000 + + +class ArithOp(Enum): + """ + Used in Op.APPLY expression to specify the function part. + """ + POS = 1 + NEG = 2 + ADD = 3 + SUB = 4 + MUL = 5 + DIV = 6 + POW = 7 + + +class OpError(Exception): + pass + + +class Precedence(Enum): + """ + Used as Expr.tostring precedence argument. + """ + NONE = 0 + TUPLE = 1 + SUM = 2 + PRODUCT = 3 + POWER = 4 + ATOM = 5 + + +integer_types = (int,) +number_types = (int, float) + + +def _pairs_add(d, k, v): + # Internal utility method for updating terms and factors data. + c = d.get(k) + if c is None: + d[k] = v + else: + c = c + v + if c: + d[k] = c + else: + del d[k] + + +class Expr: + """Represents a Fortran expression as a op-data pair. + + Expr instances are hashable and sortable. + """ + + @staticmethod + def parse(s): + """Parse a Fortran expression to a Expr. + """ + return fromstring(s) + + def __init__(self, op, data): + assert isinstance(op, Op) + + # sanity checks + if op is Op.INTEGER: + # data is a 2-tuple of numeric object and a kind value + # (default is 4) + assert isinstance(data, tuple) and len(data) == 2 + assert isinstance(data[0], int) + assert isinstance(data[1], (int, str)), data + elif op is Op.REAL: + # data is a 2-tuple of numeric object and a kind value + # (default is 4) + assert isinstance(data, tuple) and len(data) == 2 + assert isinstance(data[0], float) + assert isinstance(data[1], (int, str)), data + elif op is Op.COMPLEX: + # data is a 2-tuple of constant expressions + assert isinstance(data, tuple) and len(data) == 2 + elif op is Op.STRING: + # data is a 2-tuple of quoted string and a kind value + # (default is 1) + assert isinstance(data, tuple) and len(data) == 2 + assert (isinstance(data[0], str) + and data[0][::len(data[0])-1] in ('""', "''", '@@')) + assert isinstance(data[1], (int, str)), data + elif op is Op.SYMBOL: + # data is any hashable object + assert hash(data) is not None + elif op in (Op.ARRAY, Op.CONCAT): + # data is a tuple of expressions + assert isinstance(data, tuple) + assert all(isinstance(item, Expr) for item in data), data + elif op in (Op.TERMS, Op.FACTORS): + # data is {:} where dict values + # are nonzero Python integers + assert isinstance(data, dict) + elif op is Op.APPLY: + # data is (, , ) where + # operands are Expr instances + assert isinstance(data, tuple) and len(data) == 3 + # function is any hashable object + assert hash(data[0]) is not None + assert isinstance(data[1], tuple) + assert isinstance(data[2], dict) + elif op is Op.INDEXING: + # data is (, ) + assert isinstance(data, tuple) and len(data) == 2 + # function is any hashable object + assert hash(data[0]) is not None + else: + raise NotImplementedError( + f'unknown op or missing sanity check: {op}') + + self.op = op + self.data = data + + def __eq__(self, other): + return (isinstance(other, Expr) + and self.op is other.op + and self.data == other.data) + + def __hash__(self): + if self.op in (Op.TERMS, Op.FACTORS): + data = tuple(sorted(self.data.items())) + elif self.op is Op.APPLY: + data = self.data[:2] + tuple(sorted(self.data[2].items())) + else: + data = self.data + return hash((self.op, data)) + + def __lt__(self, other): + if isinstance(other, Expr): + if self.op is not other.op: + return self.op.value < other.op.value + if self.op in (Op.TERMS, Op.FACTORS): + return (tuple(sorted(self.data.items())) + < tuple(sorted(other.data.items()))) + if self.op is Op.APPLY: + if self.data[:2] != other.data[:2]: + return self.data[:2] < other.data[:2] + return tuple(sorted(self.data[2].items())) < tuple( + sorted(other.data[2].items())) + return self.data < other.data + return NotImplemented + + def __repr__(self): + return f'{type(self).__name__}({self.op}, {self.data!r})' + + def __str__(self): + return self.tostring() + + def tostring(self, parent_precedence=Precedence.NONE, + language=Language.Fortran): + """Return a string representation of Expr. + """ + if self.op in (Op.INTEGER, Op.REAL): + precedence = (Precedence.SUM if self.data[0] < 0 + else Precedence.ATOM) + r = str(self.data[0]) + (f'_{self.data[1]}' + if self.data[1] != 4 else '') + elif self.op is Op.COMPLEX: + r = ', '.join(item.tostring(Precedence.TUPLE, language=language) + for item in self.data) + r = '(' + r + ')' + precedence = Precedence.ATOM + elif self.op is Op.SYMBOL: + precedence = Precedence.ATOM + r = str(self.data) + elif self.op is Op.STRING: + r = self.data[0] + if self.data[1] != 1: + r = self.data[1] + '_' + r + precedence = Precedence.ATOM + elif self.op is Op.ARRAY: + r = ', '.join(item.tostring(Precedence.TUPLE, language=language) + for item in self.data) + r = '[' + r + ']' + precedence = Precedence.ATOM + elif self.op is Op.TERMS: + terms = [] + for term, coeff in sorted(self.data.items()): + if coeff < 0: + op = ' - ' + coeff = -coeff + else: + op = ' + ' + if coeff == 1: + term = term.tostring(Precedence.SUM, language=language) + else: + if term == as_number(1): + term = str(coeff) + else: + term = f'{coeff} * ' + term.tostring( + Precedence.PRODUCT, language=language) + if terms: + terms.append(op) + elif op == ' - ': + terms.append('-') + terms.append(term) + r = ''.join(terms) or '0' + precedence = Precedence.SUM if terms else Precedence.ATOM + elif self.op is Op.FACTORS: + factors = [] + tail = [] + for base, exp in sorted(self.data.items()): + op = ' * ' + if exp == 1: + factor = base.tostring(Precedence.PRODUCT, + language=language) + elif language is Language.C: + if exp in range(2, 10): + factor = base.tostring(Precedence.PRODUCT, + language=language) + factor = ' * '.join([factor] * exp) + elif exp in range(-10, 0): + factor = base.tostring(Precedence.PRODUCT, + language=language) + tail += [factor] * -exp + continue + else: + factor = base.tostring(Precedence.TUPLE, + language=language) + factor = f'pow({factor}, {exp})' + else: + factor = base.tostring(Precedence.POWER, + language=language) + f' ** {exp}' + if factors: + factors.append(op) + factors.append(factor) + if tail: + if not factors: + factors += ['1'] + factors += ['/', '(', ' * '.join(tail), ')'] + r = ''.join(factors) or '1' + precedence = Precedence.PRODUCT if factors else Precedence.ATOM + elif self.op is Op.APPLY: + name, args, kwargs = self.data + if name is ArithOp.DIV and language is Language.C: + numer, denom = [arg.tostring(Precedence.PRODUCT, + language=language) + for arg in args] + r = f'{numer} / {denom}' + precedence = Precedence.PRODUCT + else: + args = [arg.tostring(Precedence.TUPLE, language=language) + for arg in args] + args += [k + '=' + v.tostring(Precedence.NONE) + for k, v in kwargs.items()] + r = f'{name}({", ".join(args)})' + precedence = Precedence.ATOM + elif self.op is Op.INDEXING: + name = self.data[0] + args = [arg.tostring(Precedence.TUPLE, language=language) + for arg in self.data[1:]] + r = f'{name}[{", ".join(args)}]' + precedence = Precedence.ATOM + elif self.op is Op.CONCAT: + args = [arg.tostring(Precedence.PRODUCT, language=language) + for arg in self.data] + r = " // ".join(args) + precedence = Precedence.PRODUCT + else: + raise NotImplementedError(f'tostring for op {self.op}') + if parent_precedence.value > precedence.value: + # If parent precedence is higher than operand precedence, + # operand will be enclosed in parenthesis. + return '(' + r + ')' + return r + + def __pos__(self): + return self + + def __neg__(self): + return self * -1 + + def __add__(self, other): + other = as_expr(other) + if isinstance(other, Expr): + if self.op is other.op: + if self.op in (Op.INTEGER, Op.REAL): + return as_number( + self.data[0] + other.data[0], + max(self.data[1], other.data[1])) + if self.op is Op.COMPLEX: + r1, i1 = self.data + r2, i2 = other.data + return as_complex(r1 + r2, i1 + i2) + if self.op is Op.TERMS: + r = Expr(self.op, dict(self.data)) + for k, v in other.data.items(): + _pairs_add(r.data, k, v) + return normalize(r) + if self.op is Op.COMPLEX and other.op in (Op.INTEGER, Op.REAL): + return self + as_complex(other) + elif self.op is (Op.INTEGER, Op.REAL) and other.op is Op.COMPLEX: + return as_complex(self) + other + elif self.op is Op.REAL and other.op is Op.INTEGER: + return self + as_real(other, kind=self.data[1]) + elif self.op is Op.INTEGER and other.op is Op.REAL: + return as_real(self, kind=other.data[1]) + other + return as_terms(self) + as_terms(other) + return NotImplemented + + def __radd__(self, other): + if isinstance(other, number_types): + return as_number(other) + self + return NotImplemented + + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + if isinstance(other, number_types): + return as_number(other) - self + return NotImplemented + + def __mul__(self, other): + other = as_expr(other) + if isinstance(other, Expr): + if self.op is other.op: + if self.op in (Op.INTEGER, Op.REAL): + return as_number(self.data[0] * other.data[0], + max(self.data[1], other.data[1])) + elif self.op is Op.COMPLEX: + r1, i1 = self.data + r2, i2 = other.data + return as_complex(r1 * r2 - i1 * i2, r1 * i2 + r2 * i1) + + if self.op is Op.FACTORS: + r = Expr(self.op, dict(self.data)) + for k, v in other.data.items(): + _pairs_add(r.data, k, v) + return normalize(r) + elif self.op is Op.TERMS: + r = Expr(self.op, {}) + for t1, c1 in self.data.items(): + for t2, c2 in other.data.items(): + _pairs_add(r.data, t1 * t2, c1 * c2) + return normalize(r) + + if self.op is Op.COMPLEX and other.op in (Op.INTEGER, Op.REAL): + return self * as_complex(other) + elif other.op is Op.COMPLEX and self.op in (Op.INTEGER, Op.REAL): + return as_complex(self) * other + elif self.op is Op.REAL and other.op is Op.INTEGER: + return self * as_real(other, kind=self.data[1]) + elif self.op is Op.INTEGER and other.op is Op.REAL: + return as_real(self, kind=other.data[1]) * other + + if self.op is Op.TERMS: + return self * as_terms(other) + elif other.op is Op.TERMS: + return as_terms(self) * other + + return as_factors(self) * as_factors(other) + return NotImplemented + + def __rmul__(self, other): + if isinstance(other, number_types): + return as_number(other) * self + return NotImplemented + + def __pow__(self, other): + other = as_expr(other) + if isinstance(other, Expr): + if other.op is Op.INTEGER: + exponent = other.data[0] + # TODO: other kind not used + if exponent == 0: + return as_number(1) + if exponent == 1: + return self + if exponent > 0: + if self.op is Op.FACTORS: + r = Expr(self.op, {}) + for k, v in self.data.items(): + r.data[k] = v * exponent + return normalize(r) + return self * (self ** (exponent - 1)) + elif exponent != -1: + return (self ** (-exponent)) ** -1 + return Expr(Op.FACTORS, {self: exponent}) + return as_apply(ArithOp.POW, self, other) + return NotImplemented + + def __truediv__(self, other): + other = as_expr(other) + if isinstance(other, Expr): + # Fortran / is different from Python /: + # - `/` is a truncate operation for integer operands + return normalize(as_apply(ArithOp.DIV, self, other)) + return NotImplemented + + def __rtruediv__(self, other): + other = as_expr(other) + if isinstance(other, Expr): + return other / self + return NotImplemented + + def __floordiv__(self, other): + other = as_expr(other) + if isinstance(other, Expr): + # Fortran // is different from Python //: + # - `//` is a concatenate operation for string operands + return normalize(Expr(Op.CONCAT, (self, other))) + return NotImplemented + + def __rfloordiv__(self, other): + other = as_expr(other) + if isinstance(other, Expr): + return other // self + return NotImplemented + + def __call__(self, *args, **kwargs): + # In Fortran, parenthesis () are use for both function call as + # well as indexing operations. + # + # TODO: implement a method for deciding when __call__ should + # return an INDEXING expression. + return as_apply(self, *map(as_expr, args), + **dict((k, as_expr(v)) for k, v in kwargs.items())) + + def __getitem__(self, index): + # Provided to support C indexing operations that .pyf files + # may contain. + index = as_expr(index) + if not isinstance(index, tuple): + index = index, + if len(index) > 1: + warnings.warn( + f'C-index should be a single expression but got `{index}`') + return Expr(Op.INDEXING, (self,) + index) + + def substitute(self, symbols_map): + """Recursively substitute symbols with values in symbols map. + + Symbols map is a dictionary of symbol-expression pairs. + """ + if self.op is Op.SYMBOL: + value = symbols_map.get(self) + if value is None: + return self + m = re.match(r'\A(@__f2py_PARENTHESIS_(\w+)_\d+@)\Z', self.data) + if m: + # complement to fromstring method + items, paren = m.groups() + if paren in ['ROUNDDIV', 'SQUARE']: + return as_array(value) + assert paren == 'ROUND', (paren, value) + return value + if self.op in (Op.INTEGER, Op.REAL, Op.STRING): + return self + if self.op in (Op.ARRAY, Op.COMPLEX): + return Expr(self.op, tuple(item.substitute(symbols_map) + for item in self.data)) + if self.op is Op.CONCAT: + return normalize(Expr(self.op, tuple(item.substitute(symbols_map) + for item in self.data))) + if self.op is Op.TERMS: + r = None + for term, coeff in self.data.items(): + if r is None: + r = term.substitute(symbols_map) * coeff + else: + r += term.substitute(symbols_map) * coeff + if r is None: + warnings.warn('substitute: empty TERMS expression interpreted' + ' as int-literal 0') + return as_number(0) + return r + if self.op is Op.FACTORS: + r = None + for base, exponent in self.data.items(): + if r is None: + r = base.substitute(symbols_map) ** exponent + else: + r *= base.substitute(symbols_map) ** exponent + if r is None: + warnings.warn( + 'substitute: empty FACTORS expression interpreted' + ' as int-literal 1') + return as_number(1) + return r + if self.op is Op.APPLY: + target, args, kwargs = self.data + if isinstance(target, Expr): + target = target.substitute(symbols_map) + args = tuple(a.substitute(symbols_map) for a in args) + kwargs = dict((k, v.substitute(symbols_map)) + for k, v in kwargs.items()) + return normalize(Expr(self.op, (target, args, kwargs))) + if self.op is Op.INDEXING: + func = self.data[0] + if isinstance(func, Expr): + func = func.substitute(symbols_map) + args = tuple(a.substitute(symbols_map) for a in self.data[1:]) + return normalize(Expr(self.op, (func,) + args)) + raise NotImplementedError(f'substitute method for {self.op}: {self!r}') + + def traverse(self, visit, *args, **kwargs): + """Traverse expression tree with visit function. + + The visit function is applied to an expression with given args + and kwargs. + + Traverse call returns an expression returned by visit when not + None, otherwise return a new normalized expression with + traverse-visit sub-expressions. + """ + result = visit(self, *args, **kwargs) + if result is not None: + return result + + if self.op in (Op.INTEGER, Op.REAL, Op.STRING, Op.SYMBOL): + return self + elif self.op in (Op.COMPLEX, Op.ARRAY, Op.CONCAT): + return normalize(Expr(self.op, tuple( + item.traverse(visit, *args, **kwargs) + for item in self.data))) + elif self.op in (Op.TERMS, Op.FACTORS): + data = {} + for k, v in self.data.items(): + k = k.traverse(visit, *args, **kwargs) + v = (v.traverse(visit, *args, **kwargs) + if isinstance(v, Expr) else v) + if k in data: + v = data[k] + v + data[k] = v + return normalize(Expr(self.op, data)) + elif self.op is Op.APPLY: + obj = self.data[0] + func = (obj.traverse(visit, *args, **kwargs) + if isinstance(obj, Expr) else obj) + operands = tuple(operand.traverse(visit, *args, **kwargs) + for operand in self.data[1]) + kwoperands = dict((k, v.traverse(visit, *args, **kwargs)) + for k, v in self.data[2].items()) + return normalize(Expr(self.op, (func, operands, kwoperands))) + elif self.op is Op.INDEXING: + obj = self.data[0] + obj = (obj.traverse(visit, *args, **kwargs) + if isinstance(obj, Expr) else obj) + indices = tuple(index.traverse(visit, *args, **kwargs) + for index in self.data[1:]) + return normalize(Expr(self.op, (obj,) + indices)) + raise NotImplementedError(f'traverse method for {self.op}') + + def contains(self, other): + """Check if self contains other. + """ + found = [] + + def visit(expr, found=found): + if found: + return expr + elif expr == other: + found.append(1) + return expr + + self.traverse(visit) + + return len(found) != 0 + + def symbols(self): + """Return a set of symbols contained in self. + """ + found = set() + + def visit(expr, found=found): + if expr.op is Op.SYMBOL: + found.add(expr) + + self.traverse(visit) + + return found + + def polynomial_atoms(self): + """Return a set of expressions used as atoms in polynomial self. + """ + found = set() + + def visit(expr, found=found): + if expr.op is Op.FACTORS: + for b in expr.data: + b.traverse(visit) + return expr + if expr.op in (Op.TERMS, Op.COMPLEX): + return + if expr.op is Op.APPLY and isinstance(expr.data[0], ArithOp): + if expr.data[0] is ArithOp.POW: + expr.data[1][0].traverse(visit) + return expr + return + if expr.op in (Op.INTEGER, Op.REAL): + return expr + + found.add(expr) + + if expr.op in (Op.INDEXING, Op.APPLY): + return expr + + self.traverse(visit) + + return found + + def linear_solve(self, symbol): + """Return a, b such that a * symbol + b == self. + + If self is not linear with respect to symbol, raise RuntimeError. + """ + b = self.substitute({symbol: as_number(0)}) + ax = self - b + a = ax.substitute({symbol: as_number(1)}) + + zero, _ = as_numer_denom(a * symbol - ax) + + if zero != as_number(0): + raise RuntimeError(f'not a {symbol}-linear equation:' + f' {a} * {symbol} + {b} == {self}') + return a, b + + +def normalize(obj): + """Normalize Expr and apply basic evaluation methods. + """ + if not isinstance(obj, Expr): + return obj + + if obj.op is Op.TERMS: + d = {} + for t, c in obj.data.items(): + if c == 0: + continue + if t.op is Op.COMPLEX and c != 1: + t = t * c + c = 1 + if t.op is Op.TERMS: + for t1, c1 in t.data.items(): + _pairs_add(d, t1, c1 * c) + else: + _pairs_add(d, t, c) + if len(d) == 0: + # TODO: deterimine correct kind + return as_number(0) + elif len(d) == 1: + (t, c), = d.items() + if c == 1: + return t + return Expr(Op.TERMS, d) + + if obj.op is Op.FACTORS: + coeff = 1 + d = {} + for b, e in obj.data.items(): + if e == 0: + continue + if b.op is Op.TERMS and isinstance(e, integer_types) and e > 1: + # expand integer powers of sums + b = b * (b ** (e - 1)) + e = 1 + + if b.op in (Op.INTEGER, Op.REAL): + if e == 1: + coeff *= b.data[0] + elif e > 0: + coeff *= b.data[0] ** e + else: + _pairs_add(d, b, e) + elif b.op is Op.FACTORS: + if e > 0 and isinstance(e, integer_types): + for b1, e1 in b.data.items(): + _pairs_add(d, b1, e1 * e) + else: + _pairs_add(d, b, e) + else: + _pairs_add(d, b, e) + if len(d) == 0 or coeff == 0: + # TODO: deterimine correct kind + assert isinstance(coeff, number_types) + return as_number(coeff) + elif len(d) == 1: + (b, e), = d.items() + if e == 1: + t = b + else: + t = Expr(Op.FACTORS, d) + if coeff == 1: + return t + return Expr(Op.TERMS, {t: coeff}) + elif coeff == 1: + return Expr(Op.FACTORS, d) + else: + return Expr(Op.TERMS, {Expr(Op.FACTORS, d): coeff}) + + if obj.op is Op.APPLY and obj.data[0] is ArithOp.DIV: + dividend, divisor = obj.data[1] + t1, c1 = as_term_coeff(dividend) + t2, c2 = as_term_coeff(divisor) + if isinstance(c1, integer_types) and isinstance(c2, integer_types): + g = gcd(c1, c2) + c1, c2 = c1//g, c2//g + else: + c1, c2 = c1/c2, 1 + + if t1.op is Op.APPLY and t1.data[0] is ArithOp.DIV: + numer = t1.data[1][0] * c1 + denom = t1.data[1][1] * t2 * c2 + return as_apply(ArithOp.DIV, numer, denom) + + if t2.op is Op.APPLY and t2.data[0] is ArithOp.DIV: + numer = t2.data[1][1] * t1 * c1 + denom = t2.data[1][0] * c2 + return as_apply(ArithOp.DIV, numer, denom) + + d = dict(as_factors(t1).data) + for b, e in as_factors(t2).data.items(): + _pairs_add(d, b, -e) + numer, denom = {}, {} + for b, e in d.items(): + if e > 0: + numer[b] = e + else: + denom[b] = -e + numer = normalize(Expr(Op.FACTORS, numer)) * c1 + denom = normalize(Expr(Op.FACTORS, denom)) * c2 + + if denom.op in (Op.INTEGER, Op.REAL) and denom.data[0] == 1: + # TODO: denom kind not used + return numer + return as_apply(ArithOp.DIV, numer, denom) + + if obj.op is Op.CONCAT: + lst = [obj.data[0]] + for s in obj.data[1:]: + last = lst[-1] + if ( + last.op is Op.STRING + and s.op is Op.STRING + and last.data[0][0] in '"\'' + and s.data[0][0] == last.data[0][-1] + ): + new_last = as_string(last.data[0][:-1] + s.data[0][1:], + max(last.data[1], s.data[1])) + lst[-1] = new_last + else: + lst.append(s) + if len(lst) == 1: + return lst[0] + return Expr(Op.CONCAT, tuple(lst)) + return obj + + +def as_expr(obj): + """Convert non-Expr objects to Expr objects. + """ + if isinstance(obj, complex): + return as_complex(obj.real, obj.imag) + if isinstance(obj, number_types): + return as_number(obj) + if isinstance(obj, str): + # STRING expression holds string with boundary quotes, hence + # applying repr: + return as_string(repr(obj)) + if isinstance(obj, tuple): + return tuple(map(as_expr, obj)) + return obj + + +def as_symbol(obj): + """Return object as SYMBOL expression (variable or unparsed expression). + """ + return Expr(Op.SYMBOL, obj) + + +def as_number(obj, kind=4): + """Return object as INTEGER or REAL constant. + """ + if isinstance(obj, int): + return Expr(Op.INTEGER, (obj, kind)) + if isinstance(obj, float): + return Expr(Op.REAL, (obj, kind)) + if isinstance(obj, Expr): + if obj.op in (Op.INTEGER, Op.REAL): + return obj + raise OpError(f'cannot convert {obj} to INTEGER or REAL constant') + + +def as_integer(obj, kind=4): + """Return object as INTEGER constant. + """ + if isinstance(obj, int): + return Expr(Op.INTEGER, (obj, kind)) + if isinstance(obj, Expr): + if obj.op is Op.INTEGER: + return obj + raise OpError(f'cannot convert {obj} to INTEGER constant') + + +def as_real(obj, kind=4): + """Return object as REAL constant. + """ + if isinstance(obj, int): + return Expr(Op.REAL, (float(obj), kind)) + if isinstance(obj, float): + return Expr(Op.REAL, (obj, kind)) + if isinstance(obj, Expr): + if obj.op is Op.REAL: + return obj + elif obj.op is Op.INTEGER: + return Expr(Op.REAL, (float(obj.data[0]), kind)) + raise OpError(f'cannot convert {obj} to REAL constant') + + +def as_string(obj, kind=1): + """Return object as STRING expression (string literal constant). + """ + return Expr(Op.STRING, (obj, kind)) + + +def as_array(obj): + """Return object as ARRAY expression (array constant). + """ + if isinstance(obj, Expr): + obj = obj, + return Expr(Op.ARRAY, obj) + + +def as_complex(real, imag=0): + """Return object as COMPLEX expression (complex literal constant). + """ + return Expr(Op.COMPLEX, (as_expr(real), as_expr(imag))) + + +def as_apply(func, *args, **kwargs): + """Return object as APPLY expression (function call, constructor, etc.) + """ + return Expr(Op.APPLY, + (func, tuple(map(as_expr, args)), + dict((k, as_expr(v)) for k, v in kwargs.items()))) + + +def as_terms(obj): + """Return expression as TERMS expression. + """ + if isinstance(obj, Expr): + obj = normalize(obj) + if obj.op is Op.TERMS: + return obj + if obj.op is Op.INTEGER: + return Expr(Op.TERMS, {as_integer(1, obj.data[1]): obj.data[0]}) + if obj.op is Op.REAL: + return Expr(Op.TERMS, {as_real(1, obj.data[1]): obj.data[0]}) + return Expr(Op.TERMS, {obj: 1}) + raise OpError(f'cannot convert {type(obj)} to terms Expr') + + +def as_factors(obj): + """Return expression as FACTORS expression. + """ + if isinstance(obj, Expr): + obj = normalize(obj) + if obj.op is Op.FACTORS: + return obj + if obj.op is Op.TERMS: + if len(obj.data) == 1: + (term, coeff), = obj.data.items() + if coeff == 1: + return Expr(Op.FACTORS, {term: 1}) + return Expr(Op.FACTORS, {term: 1, Expr.number(coeff): 1}) + if ((obj.op is Op.APPLY + and obj.data[0] is ArithOp.DIV + and not obj.data[2])): + return Expr(Op.FACTORS, {obj.data[1][0]: 1, obj.data[1][1]: -1}) + return Expr(Op.FACTORS, {obj: 1}) + raise OpError(f'cannot convert {type(obj)} to terms Expr') + + +def as_term_coeff(obj): + """Return expression as term-coefficient pair. + """ + if isinstance(obj, Expr): + obj = normalize(obj) + if obj.op is Op.INTEGER: + return as_integer(1, obj.data[1]), obj.data[0] + if obj.op is Op.REAL: + return as_real(1, obj.data[1]), obj.data[0] + if obj.op is Op.TERMS: + if len(obj.data) == 1: + (term, coeff), = obj.data.items() + return term, coeff + # TODO: find common divisior of coefficients + if obj.op is Op.APPLY and obj.data[0] is ArithOp.DIV: + t, c = as_term_coeff(obj.data[1][0]) + return as_apply(ArithOp.DIV, t, obj.data[1][1]), c + return obj, 1 + raise OpError(f'cannot convert {type(obj)} to term and coeff') + + +def as_numer_denom(obj): + """Return expression as numer-denom pair. + """ + if isinstance(obj, Expr): + obj = normalize(obj) + if obj.op in (Op.INTEGER, Op.REAL, Op.COMPLEX, Op.SYMBOL, Op.INDEXING): + return obj, as_number(1) + elif obj.op is Op.APPLY: + if obj.data[0] is ArithOp.DIV and not obj.data[2]: + numers, denoms = map(as_numer_denom, obj.data[1]) + return numers[0] * denoms[1], numers[1] * denoms[0] + return obj, as_number(1) + elif obj.op is Op.TERMS: + numers, denoms = [], [] + for term, coeff in obj.data.items(): + n, d = as_numer_denom(term) + n = n * coeff + numers.append(n) + denoms.append(d) + numer, denom = as_number(0), as_number(1) + for i in range(len(numers)): + n = numers[i] + for j in range(len(numers)): + if i != j: + n *= denoms[j] + numer += n + denom *= denoms[i] + if denom.op in (Op.INTEGER, Op.REAL) and denom.data[0] < 0: + numer, denom = -numer, -denom + return numer, denom + elif obj.op is Op.FACTORS: + numer, denom = as_number(1), as_number(1) + for b, e in obj.data.items(): + bnumer, bdenom = as_numer_denom(b) + if e > 0: + numer *= bnumer ** e + denom *= bdenom ** e + elif e < 0: + numer *= bdenom ** (-e) + denom *= bnumer ** (-e) + return numer, denom + raise OpError(f'cannot convert {type(obj)} to numer and denom') + + +def _counter(): + # Used internally to generate unique dummy symbols + counter = 0 + while True: + counter += 1 + yield counter + + +COUNTER = _counter() + + +def replace_quotes(s): + """Replace quoted substrings of input. + + Return a new string and a mapping of replacements. + """ + d = {} + + def repl(m): + kind, value = m.groups()[:2] + if kind: + # remove trailing underscore + kind = kind[:-1] + p = {"'": "SINGLE", '"': "DOUBLE"}[value[0]] + k = f'@__f2py_QUOTES_{p}_{COUNTER.__next__()}@' + d[as_symbol(k)] = as_string(value, kind or 1) + return k + + return (re.sub(r'({kind}_|)({single_quoted}|{double_quoted})'.format( + kind=r'\w[\w\d_]*', + single_quoted=r"('([^'\\]|(\\.))*')", + double_quoted=r'("([^"\\]|(\\.))*")'), + repl, s), d) + + +def replace_parenthesis(s): + """Replace substrings of input that are enclosed in parenthesis. + + Return a new string and a mapping of replacements. + """ + # Find a parenthesis pair that appears first. + + # Fortran deliminator are `(`, `)`, `[`, `]`, `(/', '/)`, `/`. + # We don't handle `/` deliminator because it is not a part of an + # expression. + left, right = None, None + mn_i = len(s) + for left_, right_ in (('(/', '/)'), + '()', + '{}', # to support C literal structs + '[]'): + i = s.find(left_) + if i == -1: + continue + if i < mn_i: + mn_i = i + left, right = left_, right_ + + if left is None: + return s, {} + + i = mn_i + j = s.find(right, i) + + while s.count(left, i + 1, j) != s.count(right, i + 1, j): + j = s.find(right, j + 1) + if j == -1: + raise ValueError(f'Mismatch of {left+right} parenthesis in {s!r}') + + p = {'(': 'ROUND', '[': 'SQUARE', '{': 'CURLY', '(/': 'ROUNDDIV'}[left] + + k = f'@__f2py_PARENTHESIS_{p}_{COUNTER.__next__()}@' + v = s[i+len(left):j] + r, d = replace_parenthesis(s[j+len(right):]) + d[k] = v + return s[:i] + k + r, d + + +def fromstring(s): + """Create an expression from string. + + This is a "lazy" parser, that is, only arithmetic operations are + resolved, non-arithmetic operations are treated as symbols. + """ + # We replace string literal constants that may include parenthesis + # that may confuse the arithmetic expression parser. + s, quotes_map = replace_quotes(s) + + # Apply arithmetic expression parser. + r = _fromstring_worker(s) + + # Restore string constants. + return r.substitute(quotes_map) + + +class _Pair: + # Internal class to represent a pair of expressions + + def __init__(self, left, right): + self.left = left + self.right = right + + def substitute(self, symbols_map): + left, right = self.left, self.right + if isinstance(left, Expr): + left = left.substitute(symbols_map) + if isinstance(right, Expr): + right = right.substitute(symbols_map) + return _Pair(left, right) + + def __repr__(self): + return f'{type(self).__name__}({self.left}, {self.right})' + + +def _fromstring_worker(s, dummy=None): + # Internal method used by fromstring to convert a string of + # arithmetic expressions to an expression tree. + + # Replace subexpression enclosed in parenthesis with dummy symbols + # and recursively parse subexpression into symbol-expression map: + r, raw_symbols_map = replace_parenthesis(s) + r = r.strip() + + symbols_map = dict([(as_symbol(k), _fromstring_worker(v)) + for k, v in raw_symbols_map.items()]) + + if ',' in r: + # comma-separate tuple of expressions + return tuple(_fromstring_worker(r_).substitute(symbols_map) + for r_ in r.split(',')) + + m = re.match(r'\A(\w[\w\d_]*)\s*[=](.*)\Z', r) + if m: + # keyword argument + keyname, value = m.groups() + return _Pair(keyname, + _fromstring_worker(value).substitute(symbols_map)) + + operands = re.split(r'((? 1: + # Expression is an arithmetic sum + result = _fromstring_worker(operands[0] or '0') + for op, operand in zip(operands[1::2], operands[2::2]): + op = op.strip() + if op == '+': + result += _fromstring_worker(operand) + else: + assert op == '-' + result -= _fromstring_worker(operand) + return result.substitute(symbols_map) + + operands = re.split(r'//', r) + if len(operands) > 1: + # Expression is string concatenate operation + return Expr(Op.CONCAT, + tuple(_fromstring_worker(operand).substitute(symbols_map) + for operand in operands)) + + operands = re.split(r'([*]|/)', r.replace('**', '@__f2py_DOUBLE_STAR@')) + operands = [operand.replace('@__f2py_DOUBLE_STAR@', '**') + for operand in operands] + if len(operands) > 1: + # Expression is an arithmetic product + result = _fromstring_worker(operands[0]) + for op, operand in zip(operands[1::2], operands[2::2]): + op = op.strip() + if op == '*': + result *= _fromstring_worker(operand) + else: + assert op == '/' + result /= _fromstring_worker(operand) + return result.substitute(symbols_map) + + operands = list(reversed(re.split(r'[*][*]', r))) + if len(operands) > 1: + # Expression is an arithmetic exponentiation + result = _fromstring_worker(operands[0]) + for operand in operands[1:]: + result = _fromstring_worker(operand) ** result + return result + + m = re.match(r'\A({digit_string})({kind}|)\Z'.format( + digit_string=r'\d+', + kind=r'_(\d+|\w[\w\d_]*)'), r) + if m: + # Expression is int-literal-constant + value, _, kind = m.groups() + if kind and kind.isdigit(): + kind = int(kind) + return as_integer(int(value), kind or 4) + + m = re.match(r'\A({significant}({exponent}|)|\d+{exponent})({kind}|)\Z' + .format( + significant=r'[.]\d+|\d+[.]\d*', + exponent=r'[edED][+-]?\d+', + kind=r'_(\d+|\w[\w\d_]*)'), r) + if m: + # Expression is real-literal-constant + value, _, _, kind = m.groups() + if kind and kind.isdigit(): + kind = int(kind) + value = value.lower() + if 'd' in value: + return as_real(float(value.replace('d', 'e')), kind or 8) + return as_real(float(value), kind or 4) + + m = re.match(r'\A(\w[\w\d_]*)_(@__f2py_QUOTES_(\w+)_\d+@)\Z', r) + if m: + # string-literal-constant with kind parameter specification + kind, name, _ = m.groups() + return as_string(name, kind) + + m = re.match(r'\A(\w[\w\d+]*)\s*(@__f2py_PARENTHESIS_(\w+)_\d+@)\Z', r) + if m: + # apply or indexing expression + name, args, paren = m.groups() + target = as_symbol(name).substitute(symbols_map) + if args in raw_symbols_map: + args = symbols_map[as_symbol(args)] + if not isinstance(args, tuple): + args = args, + if paren == 'ROUND': + # Expression is a function call or a Fortran indexing + # operation + kwargs = dict((a.left, a.right) for a in args + if isinstance(a, _Pair)) + args = tuple(a for a in args if not isinstance(a, _Pair)) + return as_apply(target, *args, **kwargs) + if paren == 'SQUARE': + # Expression is a C indexing operation (e.g. used in + # .pyf files) + return target[args] + # A non-Fortran/C expression? + assert 0, (s, r) + + m = re.match(r'\A(@__f2py_PARENTHESIS_(\w+)_\d+@)\Z', r) + if m: + # array constructor or literal complex constant + items, paren = m.groups() + if items in raw_symbols_map: + items = symbols_map[as_symbol(items)] + if (paren == 'ROUND' and isinstance(items, + tuple) and len(items) == 2): + # Expression is a literal complex constant + return as_complex(items[0], items[1]) + if paren in ['ROUNDDIV', 'SQUARE']: + # Expression is a array constructor + return as_array(items) + assert 0, (r, items, paren) + + m = re.match(r'\A\w[\w\d_]*\Z', r) + if m: + # Fortran standard conforming name + return as_symbol(r) + + m = re.match(r'\A@\w[\w\d_]*@\Z', r) + if m: + # f2py special dummy name + return as_symbol(r) + + warnings.warn(f'fromstring: treating {r!r} as symbol') + return as_symbol(r) diff --git a/numpy/f2py/tests/test_crackfortran.py b/numpy/f2py/tests/test_crackfortran.py index 140f42cbcd43..da7974d1abd3 100644 --- a/numpy/f2py/tests/test_crackfortran.py +++ b/numpy/f2py/tests/test_crackfortran.py @@ -1,3 +1,4 @@ +import pytest import numpy as np from numpy.testing import assert_array_equal, assert_equal from numpy.f2py.crackfortran import markinnerspaces @@ -39,6 +40,7 @@ def test_module(self): class TestPublicPrivate(): + def test_defaultPrivate(self, tmp_path): f_path = tmp_path / "mod.f90" with f_path.open('w') as ff: @@ -165,3 +167,100 @@ def test_ignore_inner_quotes(self): def test_multiple_relevant_spaces(self): assert_equal(markinnerspaces("a 'b c' 'd e'"), "a 'b@_@c' 'd@_@e'") assert_equal(markinnerspaces(r'a "b c" "d e"'), r'a "b@_@c" "d@_@e"') + + +class TestDimSpec(util.F2PyTest): + """This test site tests various expressions that are used as dimension + specifications. + + There exists two usage cases where analyzing dimensions + specifications are important. + + In the first case, the size of output arrays must be defined based + on the inputs to a Fortran function. Because Fortran supports + arbitrary bases for indexing, for instance, `arr(lower:upper)`, + f2py has to evaluate an expression `upper - lower + 1` where + `lower` and `upper` are arbitrary expressions of input parameters. + The evaluation is performed in C, so f2py has to translate Fortran + expressions to valid C expressions (an alternative approach is + that a developer specifies the corresponing C expressions in a + .pyf file). + + In the second case, when user provides an input array with a given + size but some hidden parameters used in dimensions specifications + need to be determined based on the input array size. This is a + harder problem because f2py has to solve the inverse problem: find + a parameter `p` such that `upper(p) - lower(p) + 1` equals to the + size of input array. In the case when this equation cannot be + solved (e.g. because the input array size is wrong), raise an + error before calling the Fortran function (that otherwise would + likely crash Python process when the size of input arrays is + wrong). f2py currently supports this case only when the equation + is linear with respect to unknown parameter. + + """ + + suffix = '.f90' + + code_template = textwrap.dedent(""" + function get_arr_size_{count}(a, n) result (length) + integer, intent(in) :: n + integer, dimension({dimspec}), intent(out) :: a + integer length + length = size(a) + end function + + subroutine get_inv_arr_size_{count}(a, n) + integer :: n + ! the value of n is computed in f2py wrapper + !f2py intent(out) n + integer, dimension({dimspec}), intent(in) :: a + if (a({first}).gt.0) then + print*, "a=", a + endif + end subroutine + """) + + linear_dimspecs = ['n', '2*n', '2:n', 'n/2', '5 - n/2', '3*n:20', + 'n*(n+1):n*(n+5)'] + nonlinear_dimspecs = ['2*n:3*n*n+2*n'] + all_dimspecs = linear_dimspecs + nonlinear_dimspecs + + code = '' + for count, dimspec in enumerate(all_dimspecs): + code += code_template.format( + count=count, dimspec=dimspec, + first=dimspec.split(':')[0] if ':' in dimspec else '1') + + @pytest.mark.parametrize('dimspec', all_dimspecs) + def test_array_size(self, dimspec): + + count = self.all_dimspecs.index(dimspec) + get_arr_size = getattr(self.module, f'get_arr_size_{count}') + + for n in [1, 2, 3, 4, 5]: + sz, a = get_arr_size(n) + assert len(a) == sz + + @pytest.mark.parametrize('dimspec', all_dimspecs) + def test_inv_array_size(self, dimspec): + + count = self.all_dimspecs.index(dimspec) + get_arr_size = getattr(self.module, f'get_arr_size_{count}') + get_inv_arr_size = getattr(self.module, f'get_inv_arr_size_{count}') + + for n in [1, 2, 3, 4, 5]: + sz, a = get_arr_size(n) + if dimspec in self.nonlinear_dimspecs: + # one must specify n as input, the call we'll ensure + # that a and n are compatible: + n1 = get_inv_arr_size(a, n) + else: + # in case of linear dependence, n can be determined + # from the shape of a: + n1 = get_inv_arr_size(a) + # n1 may be different from n (for instance, when `a` size + # is a function of some `n` fraction) but it must produce + # the same sized array + sz1, _ = get_arr_size(n1) + assert sz == sz1, (n, n1, sz, sz1) diff --git a/numpy/f2py/tests/test_symbolic.py b/numpy/f2py/tests/test_symbolic.py new file mode 100644 index 000000000000..27244c995636 --- /dev/null +++ b/numpy/f2py/tests/test_symbolic.py @@ -0,0 +1,395 @@ +from numpy.testing import assert_raises +from numpy.f2py.symbolic import ( + Expr, Op, ArithOp, Language, + as_symbol, as_number, as_string, as_array, as_complex, + as_terms, as_factors, fromstring, replace_quotes, as_expr, as_apply, + as_numer_denom + ) +from . import util + + +class TestSymbolic(util.F2PyTest): + + def test_replace_quotes(self): + + def worker(s): + r, d = replace_quotes(s) + assert '"' not in r, r + assert "'" not in r, r + # undo replacements: + s1 = r + for k, v in d.items(): + s1 = s1.replace(k.data, v.data[0]) + assert s1 == s + + worker('"1234" // "ABCD"') + worker('"1234" // \'ABCD\'') + worker('"1\\"2\'AB\'34"') + worker("'1\\'2\"AB\"34'") + + def test_sanity(self): + x = as_symbol('x') + y = as_symbol('y') + assert x.op == Op.SYMBOL + assert repr(x) == "Expr(Op.SYMBOL, 'x')" + assert x == x + assert x != y + assert hash(x) is not None + + n = as_number(123) + m = as_number(456) + assert n.op == Op.INTEGER + assert repr(n) == "Expr(Op.INTEGER, (123, 4))" + assert n == n + assert n != m + assert hash(n) is not None + + fn = as_number(12.3) + fm = as_number(45.6) + assert fn.op == Op.REAL + assert repr(fn) == "Expr(Op.REAL, (12.3, 4))" + assert fn == fn + assert fn != fm + assert hash(fn) is not None + + c = as_complex(1, 2) + c2 = as_complex(3, 4) + assert c.op == Op.COMPLEX + assert repr(c) == ("Expr(Op.COMPLEX, (Expr(Op.INTEGER, (1, 4))," + " Expr(Op.INTEGER, (2, 4))))") + assert c == c + assert c != c2 + assert hash(c) is not None + + s = as_string("'123'") + s2 = as_string('"ABC"') + assert s.op == Op.STRING + assert repr(s) == "Expr(Op.STRING, (\"'123'\", 1))", repr(s) + assert s == s + assert s != s2 + + a = as_array((n, m)) + b = as_array((n,)) + assert a.op == Op.ARRAY + assert repr(a) == ("Expr(Op.ARRAY, (Expr(Op.INTEGER, (123, 4))," + " Expr(Op.INTEGER, (456, 4))))") + assert a == a + assert a != b + + t = as_terms(x) + u = as_terms(y) + assert t.op == Op.TERMS + assert repr(t) == "Expr(Op.TERMS, {Expr(Op.SYMBOL, 'x'): 1})" + assert t == t + assert t != u + assert hash(t) is not None + + v = as_factors(x) + w = as_factors(y) + assert v.op == Op.FACTORS + assert repr(v) == "Expr(Op.FACTORS, {Expr(Op.SYMBOL, 'x'): 1})" + assert v == v + assert w != v + assert hash(v) is not None + + def test_tostring_fortran(self): + x = as_symbol('x') + y = as_symbol('y') + n = as_number(123) + m = as_number(456) + a = as_array((n, m)) + c = as_complex(n, m) + + assert str(x) == 'x' + assert str(n) == '123' + assert str(a) == '[123, 456]' + assert str(c) == '(123, 456)' + + assert str(Expr(Op.TERMS, {x: 1})) == 'x' + assert str(Expr(Op.TERMS, {x: 2})) == '2 * x' + assert str(Expr(Op.TERMS, {x: -1})) == '-x' + assert str(Expr(Op.TERMS, {x: -2})) == '-2 * x' + assert str(Expr(Op.TERMS, {x: 1, y: 1})) == 'x + y' + assert str(Expr(Op.TERMS, {x: -1, y: -1})) == '-x - y' + assert str(Expr(Op.TERMS, {x: 2, y: 3})) == '2 * x + 3 * y' + assert str(Expr(Op.TERMS, {x: -2, y: 3})) == '-2 * x + 3 * y' + assert str(Expr(Op.TERMS, {x: 2, y: -3})) == '2 * x - 3 * y' + + assert str(Expr(Op.FACTORS, {x: 1})) == 'x' + assert str(Expr(Op.FACTORS, {x: 2})) == 'x ** 2' + assert str(Expr(Op.FACTORS, {x: -1})) == 'x ** -1' + assert str(Expr(Op.FACTORS, {x: -2})) == 'x ** -2' + assert str(Expr(Op.FACTORS, {x: 1, y: 1})) == 'x * y' + assert str(Expr(Op.FACTORS, {x: 2, y: 3})) == 'x ** 2 * y ** 3' + + v = Expr(Op.FACTORS, {x: 2, Expr(Op.TERMS, {x: 1, y: 1}): 3}) + assert str(v) == 'x ** 2 * (x + y) ** 3', str(v) + v = Expr(Op.FACTORS, {x: 2, Expr(Op.FACTORS, {x: 1, y: 1}): 3}) + assert str(v) == 'x ** 2 * (x * y) ** 3', str(v) + + assert str(Expr(Op.APPLY, ('f', (), {}))) == 'f()' + assert str(Expr(Op.APPLY, ('f', (x,), {}))) == 'f(x)' + assert str(Expr(Op.APPLY, ('f', (x, y), {}))) == 'f(x, y)' + assert str(Expr(Op.INDEXING, ('f', x))) == 'f[x]' + + def test_tostring_c(self): + language = Language.C + x = as_symbol('x') + y = as_symbol('y') + n = as_number(123) + + assert Expr(Op.FACTORS, {x: 2}).tostring(language=language) == 'x * x' + assert Expr(Op.FACTORS, {x + y: 2}).tostring( + language=language) == '(x + y) * (x + y)' + assert Expr(Op.FACTORS, {x: 12}).tostring( + language=language) == 'pow(x, 12)' + + assert as_apply(ArithOp.DIV, x, y).tostring( + language=language) == 'x / y' + assert as_apply(ArithOp.DIV, x, x + y).tostring( + language=language) == 'x / (x + y)' + assert as_apply(ArithOp.DIV, x - y, x + y).tostring( + language=language) == '(x - y) / (x + y)' + assert (x + (x - y) / (x + y) + n).tostring( + language=language) == '123 + x + (x - y) / (x + y)' + + def test_operations(self): + x = as_symbol('x') + y = as_symbol('y') + z = as_symbol('z') + + assert x + x == Expr(Op.TERMS, {x: 2}) + assert x - x == Expr(Op.INTEGER, (0, 4)) + assert x + y == Expr(Op.TERMS, {x: 1, y: 1}) + assert x - y == Expr(Op.TERMS, {x: 1, y: -1}) + assert x * x == Expr(Op.FACTORS, {x: 2}) + assert x * y == Expr(Op.FACTORS, {x: 1, y: 1}) + + assert +x == x + assert -x == Expr(Op.TERMS, {x: -1}), repr(-x) + assert 2 * x == Expr(Op.TERMS, {x: 2}) + assert 2 + x == Expr(Op.TERMS, {x: 1, as_number(1): 2}) + assert 2 * x + 3 * y == Expr(Op.TERMS, {x: 2, y: 3}) + assert (x + y) * 2 == Expr(Op.TERMS, {x: 2, y: 2}) + + assert x ** 2 == Expr(Op.FACTORS, {x: 2}) + assert (x + y) ** 2 == Expr(Op.TERMS, + {Expr(Op.FACTORS, {x: 2}): 1, + Expr(Op.FACTORS, {y: 2}): 1, + Expr(Op.FACTORS, {x: 1, y: 1}): 2}) + assert (x + y) * x == x ** 2 + x * y + assert (x + y) ** 2 == x ** 2 + 2 * x * y + y ** 2 + assert (x + y) ** 2 + (x - y) ** 2 == 2 * x ** 2 + 2 * y ** 2 + assert (x + y) * z == x * z + y * z + assert z * (x + y) == x * z + y * z + + assert (x / 2) == as_apply(ArithOp.DIV, x, as_number(2)) + assert (2 * x / 2) == x + assert (3 * x / 2) == as_apply(ArithOp.DIV, 3*x, as_number(2)) + assert (4 * x / 2) == 2 * x + assert (5 * x / 2) == as_apply(ArithOp.DIV, 5*x, as_number(2)) + assert (6 * x / 2) == 3 * x + assert ((3*5) * x / 6) == as_apply(ArithOp.DIV, 5*x, as_number(2)) + assert (30*x**2*y**4 / (24*x**3*y**3)) == as_apply(ArithOp.DIV, + 5*y, 4*x) + assert ((15 * x / 6) / 5) == as_apply( + ArithOp.DIV, x, as_number(2)), ((15 * x / 6) / 5) + assert (x / (5 / x)) == as_apply(ArithOp.DIV, x**2, as_number(5)) + + assert (x / 2.0) == Expr(Op.TERMS, {x: 0.5}) + + s = as_string('"ABC"') + t = as_string('"123"') + + assert s // t == Expr(Op.STRING, ('"ABC123"', 1)) + assert s // x == Expr(Op.CONCAT, (s, x)) + assert x // s == Expr(Op.CONCAT, (x, s)) + + c = as_complex(1., 2.) + assert -c == as_complex(-1., -2.) + assert c + c == as_expr((1+2j)*2) + assert c * c == as_expr((1+2j)**2) + + def test_substitute(self): + x = as_symbol('x') + y = as_symbol('y') + z = as_symbol('z') + a = as_array((x, y)) + + assert x.substitute({x: y}) == y + assert (x + y).substitute({x: z}) == y + z + assert (x * y).substitute({x: z}) == y * z + assert (x ** 4).substitute({x: z}) == z ** 4 + assert (x / y).substitute({x: z}) == z / y + assert x.substitute({x: y + z}) == y + z + assert a.substitute({x: y + z}) == as_array((y + z, y)) + + def test_fromstring(self): + + x = as_symbol('x') + y = as_symbol('y') + z = as_symbol('z') + f = as_symbol('f') + s = as_string('"ABC"') + t = as_string('"123"') + a = as_array((x, y)) + + assert fromstring('x') == x + assert fromstring('+ x') == x + assert fromstring('- x') == -x + assert fromstring('x + y') == x + y + assert fromstring('x + 1') == x + 1 + assert fromstring('x * y') == x * y + assert fromstring('x * 2') == x * 2 + assert fromstring('x / y') == x / y + assert fromstring('x ** 2') == x ** 2 + assert fromstring('x ** 2 ** 3') == x ** 2 ** 3 + assert fromstring('(x + y) * z') == (x + y) * z + + assert fromstring('f(x)') == f(x) + assert fromstring('f(x,y)') == f(x, y) + assert fromstring('f[x]') == f[x] + + assert fromstring('"ABC"') == s + assert fromstring('"ABC" // "123" ') == s // t + assert fromstring('f("ABC")') == f(s) + assert fromstring('MYSTRKIND_"ABC"') == as_string('"ABC"', 'MYSTRKIND') + + assert fromstring('(/x, y/)') == a, fromstring('(/x, y/)') + assert fromstring('f((/x, y/))') == f(a) + assert fromstring('(/(x+y)*z/)') == as_array(((x+y)*z,)) + + assert fromstring('123') == as_number(123) + assert fromstring('123_2') == as_number(123, 2) + assert fromstring('123_myintkind') == as_number(123, 'myintkind') + + assert fromstring('123.0') == as_number(123.0, 4) + assert fromstring('123.0_4') == as_number(123.0, 4) + assert fromstring('123.0_8') == as_number(123.0, 8) + assert fromstring('123.0e0') == as_number(123.0, 4) + assert fromstring('123.0d0') == as_number(123.0, 8) + assert fromstring('123d0') == as_number(123.0, 8) + assert fromstring('123e-0') == as_number(123.0, 4) + assert fromstring('123d+0') == as_number(123.0, 8) + assert fromstring('123.0_myrealkind') == as_number(123.0, 'myrealkind') + assert fromstring('3E4') == as_number(30000.0, 4) + + assert fromstring('(1, 2)') == as_complex(1, 2) + assert fromstring('(1e2, PI)') == as_complex( + as_number(100.0), as_symbol('PI')) + + assert fromstring('[1, 2]') == as_array((as_number(1), as_number(2))) + + assert fromstring('POINT(x, y=1)') == as_apply( + as_symbol('POINT'), x, y=as_number(1)) + assert (fromstring('PERSON(name="John", age=50, shape=(/34, 23/))') + == as_apply(as_symbol('PERSON'), + name=as_string('"John"'), + age=as_number(50), + shape=as_array((as_number(34), as_number(23))))) + + def test_traverse(self): + x = as_symbol('x') + y = as_symbol('y') + z = as_symbol('z') + f = as_symbol('f') + + # Use traverse to substitute a symbol + def replace_visit(s, r=z): + if s == x: + return r + + assert x.traverse(replace_visit) == z + assert y.traverse(replace_visit) == y + assert z.traverse(replace_visit) == z + assert (f(y)).traverse(replace_visit) == f(y) + assert (f(x)).traverse(replace_visit) == f(z) + assert (f[y]).traverse(replace_visit) == f[y] + assert (f[z]).traverse(replace_visit) == f[z] + assert (x + y + z).traverse(replace_visit) == (2 * z + y) + assert (x + f(y, x - z)).traverse( + replace_visit) == (z + f(y, as_number(0))) + + # Use traverse to collect symbols, method 1 + function_symbols = set() + symbols = set() + + def collect_symbols(s): + if s.op is Op.APPLY: + oper = s.data[0] + function_symbols.add(oper) + if oper in symbols: + symbols.remove(oper) + elif s.op is Op.SYMBOL and s not in function_symbols: + symbols.add(s) + + (x + f(y, x - z)).traverse(collect_symbols) + assert function_symbols == {f} + assert symbols == {x, y, z} + + # Use traverse to collect symbols, method 2 + def collect_symbols2(expr, symbols): + if expr.op is Op.SYMBOL: + symbols.add(expr) + + symbols = set() + (x + f(y, x - z)).traverse(collect_symbols2, symbols) + assert symbols == {x, y, z, f} + + # Use traverse to partially collect symbols + def collect_symbols3(expr, symbols): + if expr.op is Op.APPLY: + # skip traversing function calls + return expr + if expr.op is Op.SYMBOL: + symbols.add(expr) + + symbols = set() + (x + f(y, x - z)).traverse(collect_symbols3, symbols) + assert symbols == {x} + + def test_linear_solve(self): + x = as_symbol('x') + y = as_symbol('y') + z = as_symbol('z') + + assert x.linear_solve(x) == (as_number(1), as_number(0)) + assert (x+1).linear_solve(x) == (as_number(1), as_number(1)) + assert (2*x).linear_solve(x) == (as_number(2), as_number(0)) + assert (2*x+3).linear_solve(x) == (as_number(2), as_number(3)) + assert as_number(3).linear_solve(x) == (as_number(0), as_number(3)) + assert y.linear_solve(x) == (as_number(0), y) + assert (y*z).linear_solve(x) == (as_number(0), y * z) + + assert (x+y).linear_solve(x) == (as_number(1), y) + assert (z*x+y).linear_solve(x) == (z, y) + assert ((z+y)*x+y).linear_solve(x) == (z + y, y) + assert (z*y*x+y).linear_solve(x) == (z * y, y) + + assert_raises(RuntimeError, lambda: (x*x).linear_solve(x)) + + def test_as_numer_denom(self): + x = as_symbol('x') + y = as_symbol('y') + n = as_number(123) + + assert as_numer_denom(x) == (x, as_number(1)) + assert as_numer_denom(x / n) == (x, n) + assert as_numer_denom(n / x) == (n, x) + assert as_numer_denom(x / y) == (x, y) + assert as_numer_denom(x * y) == (x * y, as_number(1)) + assert as_numer_denom(n + x / y) == (x + n * y, y) + assert as_numer_denom(n + x / (y - x / n)) == (y * n ** 2, y * n - x) + + def test_polynomial_atoms(self): + x = as_symbol('x') + y = as_symbol('y') + n = as_number(123) + + assert x.polynomial_atoms() == {x} + assert n.polynomial_atoms() == set() + assert (y[x]).polynomial_atoms() == {y[x]} + assert (y(x)).polynomial_atoms() == {y(x)} + assert (y(x) + x).polynomial_atoms() == {y(x), x} + assert (y(x) * x[y]).polynomial_atoms() == {y(x), x[y]} + assert (y(x) ** x).polynomial_atoms() == {y(x)} From d721e67fb3dc732e944493c47fbbcd22d19481dc Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 1 Sep 2021 00:31:32 +0300 Subject: [PATCH 02/10] Fix lint --- numpy/f2py/crackfortran.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py index cd30d3cec489..b613845a328f 100755 --- a/numpy/f2py/crackfortran.py +++ b/numpy/f2py/crackfortran.py @@ -2645,7 +2645,8 @@ def analyzevars(block): # initialization expressions. n_deps = vars[n].get('depend', []) n_checks = [] - n_is_input = l_or(isintent_in, isintent_inout, isintent_inplace)(vars[n]) + n_is_input = l_or(isintent_in, isintent_inout, + isintent_inplace)(vars[n]) if 'dimension' in vars[n]: # n is array ni = len(vars[n]['dimension']) # array dimensionality for i, d in enumerate(vars[n]['dimension']): @@ -2658,9 +2659,9 @@ def analyzevars(block): # may define variables used in dimension # specifications. for v, (solver, deps) in coeffs_and_deps.items(): - if (v in n_deps - or '=' in vars[v] - or 'depend' in vars[v]): + if ((v in n_deps + or '=' in vars[v] + or 'depend' in vars[v])): # Skip a variable that # - n depends on # - has user-defined initialization expression From 96e63aaf5c2720c4adc6c4440f2732ee6a73b260 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 1 Sep 2021 11:19:06 +0300 Subject: [PATCH 03/10] Add new module numpy.f2py.symbolic --- numpy/f2py/crackfortran.py | 8 ++++---- numpy/tests/test_public_api.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py index b613845a328f..b675c936dbf1 100755 --- a/numpy/f2py/crackfortran.py +++ b/numpy/f2py/crackfortran.py @@ -153,7 +153,6 @@ # As the needed functions cannot be determined by static inspection of the # code, it is safest to use import * pending a major refactoring of f2py. from .auxfuncs import * -from .symbolic import Expr, Language from . import symbolic f2py_version = __version__.version @@ -2578,9 +2577,9 @@ def analyzevars(block): if len(dl) == 1 and dl[0] != star: dl = ['1', dl[0]] if len(dl) == 2: - d1, d2 = map(Expr.parse, dl) + d1, d2 = map(symbolic.Expr.parse, dl) dsize = d2 - d1 + 1 - d = dsize.tostring(language=Language.C) + d = dsize.tostring(language=symbolic.Language.C) # find variables v that define d as a linear # function, `d == a * v + b`, and store # coefficients a and b for further analysis. @@ -2674,7 +2673,8 @@ def analyzevars(block): is_required = False init = solver(symbolic.as_symbol( f'shape({n}, {i})')) - init = init.tostring(language=Language.C) + init = init.tostring( + language=symbolic.Language.C) vars[v]['='] = init # n needs to be initialzed before v. So, # making v dependent on n and on any diff --git a/numpy/tests/test_public_api.py b/numpy/tests/test_public_api.py index 5b15785001c9..73a93f4897d9 100644 --- a/numpy/tests/test_public_api.py +++ b/numpy/tests/test_public_api.py @@ -253,6 +253,7 @@ def test_NPY_NO_EXPORT(): "f2py.f90mod_rules", "f2py.func2subr", "f2py.rules", + "f2py.symbolic", "f2py.use_rules", "fft.helper", "lib.arraypad", From f3fd6273fee44d1cc1aba502cf56aaea4b59d22c Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 1 Sep 2021 11:43:05 +0300 Subject: [PATCH 04/10] Add stacklevel to warn call. --- numpy/f2py/symbolic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py index d303da0a0d30..b71a87a76589 100644 --- a/numpy/f2py/symbolic.py +++ b/numpy/f2py/symbolic.py @@ -482,7 +482,9 @@ def __getitem__(self, index): index = index, if len(index) > 1: warnings.warn( - f'C-index should be a single expression but got `{index}`') + f'C-index should be a single expression but got `{index}`', + category=warnings.SyntaxWarning, + stacklevel=1) return Expr(Op.INDEXING, (self,) + index) def substitute(self, symbols_map): From 04585856537ccf802d881943235205d95b59527e Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 1 Sep 2021 12:03:36 +0300 Subject: [PATCH 05/10] Handle more Expr warnings. --- numpy/f2py/symbolic.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py index b71a87a76589..54b5721c0e34 100644 --- a/numpy/f2py/symbolic.py +++ b/numpy/f2py/symbolic.py @@ -99,6 +99,14 @@ def _pairs_add(d, k, v): del d[k] +class ExprWarning(warnings.UserWarning): + pass + + +def warn(message): + warnings.warn(message, ExprWarning, stacklevel=2) + + class Expr: """Represents a Fortran expression as a op-data pair. @@ -481,10 +489,7 @@ def __getitem__(self, index): if not isinstance(index, tuple): index = index, if len(index) > 1: - warnings.warn( - f'C-index should be a single expression but got `{index}`', - category=warnings.SyntaxWarning, - stacklevel=1) + warn(f'C-index should be a single expression but got `{index}`') return Expr(Op.INDEXING, (self,) + index) def substitute(self, symbols_map): @@ -520,8 +525,8 @@ def substitute(self, symbols_map): else: r += term.substitute(symbols_map) * coeff if r is None: - warnings.warn('substitute: empty TERMS expression interpreted' - ' as int-literal 0') + warn('substitute: empty TERMS expression interpreted as' + ' int-literal 0') return as_number(0) return r if self.op is Op.FACTORS: @@ -532,9 +537,8 @@ def substitute(self, symbols_map): else: r *= base.substitute(symbols_map) ** exponent if r is None: - warnings.warn( - 'substitute: empty FACTORS expression interpreted' - ' as int-literal 1') + warn('substitute: empty FACTORS expression interpreted' + ' as int-literal 1') return as_number(1) return r if self.op is Op.APPLY: @@ -1256,5 +1260,5 @@ def _fromstring_worker(s, dummy=None): # f2py special dummy name return as_symbol(r) - warnings.warn(f'fromstring: treating {r!r} as symbol') + warn(f'fromstring: treating {r!r} as symbol') return as_symbol(r) From 7b229a277a2a1ed8e3d49d361acbc0d4918f8bf5 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 1 Sep 2021 12:14:09 +0300 Subject: [PATCH 06/10] Fix a mistake --- numpy/f2py/symbolic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py index 54b5721c0e34..5a67554894e4 100644 --- a/numpy/f2py/symbolic.py +++ b/numpy/f2py/symbolic.py @@ -99,7 +99,7 @@ def _pairs_add(d, k, v): del d[k] -class ExprWarning(warnings.UserWarning): +class ExprWarning(UserWarning): pass From 26752c3ad454f72a94cc7db32ccba43a76a85a5e Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 1 Sep 2021 12:32:23 +0300 Subject: [PATCH 07/10] Workaround numpy.tests.test_warnings bug of treating local warn function as warnings.warn --- numpy/f2py/symbolic.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py index 5a67554894e4..f0b1da288a75 100644 --- a/numpy/f2py/symbolic.py +++ b/numpy/f2py/symbolic.py @@ -103,7 +103,7 @@ class ExprWarning(UserWarning): pass -def warn(message): +def ewarn(message): warnings.warn(message, ExprWarning, stacklevel=2) @@ -489,7 +489,7 @@ def __getitem__(self, index): if not isinstance(index, tuple): index = index, if len(index) > 1: - warn(f'C-index should be a single expression but got `{index}`') + ewarn(f'C-index should be a single expression but got `{index}`') return Expr(Op.INDEXING, (self,) + index) def substitute(self, symbols_map): @@ -525,8 +525,8 @@ def substitute(self, symbols_map): else: r += term.substitute(symbols_map) * coeff if r is None: - warn('substitute: empty TERMS expression interpreted as' - ' int-literal 0') + ewarn('substitute: empty TERMS expression interpreted as' + ' int-literal 0') return as_number(0) return r if self.op is Op.FACTORS: @@ -537,8 +537,8 @@ def substitute(self, symbols_map): else: r *= base.substitute(symbols_map) ** exponent if r is None: - warn('substitute: empty FACTORS expression interpreted' - ' as int-literal 1') + ewarn('substitute: empty FACTORS expression interpreted' + ' as int-literal 1') return as_number(1) return r if self.op is Op.APPLY: @@ -1260,5 +1260,5 @@ def _fromstring_worker(s, dummy=None): # f2py special dummy name return as_symbol(r) - warn(f'fromstring: treating {r!r} as symbol') + ewarn(f'fromstring: treating {r!r} as symbol') return as_symbol(r) From d880dce560fd6369d7418ba83c03cc5947993103 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 1 Sep 2021 22:54:12 +0300 Subject: [PATCH 08/10] Fix bugs and warnings from LGTM report --- numpy/f2py/crackfortran.py | 9 ++------- numpy/f2py/symbolic.py | 8 +++++++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py index b675c936dbf1..8cb1f73d2f9b 100755 --- a/numpy/f2py/crackfortran.py +++ b/numpy/f2py/crackfortran.py @@ -2415,7 +2415,8 @@ def _eval_scalar(value, params): if _is_kind_number(value): value = value.split('_')[0] try: - value = str(eval(value, {}, params)) + value = eval(value, {}, params) + value = (repr if isinstance(value, str) else str)(value) except (NameError, SyntaxError, TypeError): return value except Exception as msg: @@ -2614,10 +2615,6 @@ def analyzevars(block): vars[n]['dimension'].append(d) if 'dimension' in vars[n]: - if isintent_c(vars[n]): - shape_macro = 'shape' - else: - shape_macro = 'shape' # 'fshape' if isstringarray(vars[n]): if 'charselector' in vars[n]: d = vars[n]['charselector'] @@ -2647,7 +2644,6 @@ def analyzevars(block): n_is_input = l_or(isintent_in, isintent_inout, isintent_inplace)(vars[n]) if 'dimension' in vars[n]: # n is array - ni = len(vars[n]['dimension']) # array dimensionality for i, d in enumerate(vars[n]['dimension']): coeffs_and_deps = dimension_exprs.get(d) if coeffs_and_deps is None: @@ -2721,7 +2717,6 @@ def analyzevars(block): if v_deps: vars[v]['depend'] = list(set(v_deps)) elif isstring(vars[n]): - length = '1' if 'charselector' in vars[n]: if '*' in vars[n]['charselector']: length = _eval_length(vars[n]['charselector']['*'], diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py index f0b1da288a75..56faa309c294 100644 --- a/numpy/f2py/symbolic.py +++ b/numpy/f2py/symbolic.py @@ -205,6 +205,12 @@ def __lt__(self, other): return self.data < other.data return NotImplemented + def __le__(self, other): return self == other or self < other + + def __gt__(self, other): return not (self <= other) + + def __ge__(self, other): return not (self < other) + def __repr__(self): return f'{type(self).__name__}({self.op}, {self.data!r})' @@ -354,7 +360,7 @@ def __add__(self, other): return normalize(r) if self.op is Op.COMPLEX and other.op in (Op.INTEGER, Op.REAL): return self + as_complex(other) - elif self.op is (Op.INTEGER, Op.REAL) and other.op is Op.COMPLEX: + elif self.op in (Op.INTEGER, Op.REAL) and other.op is Op.COMPLEX: return as_complex(self) + other elif self.op is Op.REAL and other.op is Op.INTEGER: return self + as_real(other, kind=self.data[1]) From 345c9d9c4dd479755eb298dda4316809a4739fea Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Fri, 3 Sep 2021 19:05:38 +0300 Subject: [PATCH 09/10] Add ternary, ref, and deref expressions support. --- numpy/f2py/crackfortran.py | 4 +- numpy/f2py/symbolic.py | 449 ++++++++++++++++++++---------- numpy/f2py/tests/test_symbolic.py | 67 +++-- 3 files changed, 348 insertions(+), 172 deletions(-) diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py index 8cb1f73d2f9b..2adca9fba7e9 100755 --- a/numpy/f2py/crackfortran.py +++ b/numpy/f2py/crackfortran.py @@ -2609,8 +2609,8 @@ def analyzevars(block): # not defined in block['vars']. Here we assume # these correspond to Fortran/C intrinsic # functions or that are defined by other - # means. We'll let the compiler to validate - # the definiteness of such symbols. + # means. We'll let the compiler validate the + # definiteness of such symbols. dimension_exprs[d] = solver_and_deps vars[n]['dimension'].append(d) diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py index 56faa309c294..0a2c3ba0471f 100644 --- a/numpy/f2py/symbolic.py +++ b/numpy/f2py/symbolic.py @@ -46,11 +46,14 @@ class Op(Enum): STRING = 20 ARRAY = 30 SYMBOL = 40 + TERNARY = 100 APPLY = 200 INDEXING = 210 CONCAT = 220 TERMS = 1000 FACTORS = 2000 + REF = 3000 + DEREF = 3001 class ArithOp(Enum): @@ -114,10 +117,10 @@ class Expr: """ @staticmethod - def parse(s): + def parse(s, language=Language.C): """Parse a Fortran expression to a Expr. """ - return fromstring(s) + return fromstring(s, language=language) def __init__(self, op, data): assert isinstance(op, Op) @@ -169,6 +172,12 @@ def __init__(self, op, data): assert isinstance(data, tuple) and len(data) == 2 # function is any hashable object assert hash(data[0]) is not None + elif op is Op.TERNARY: + # data is (, , ) + assert isinstance(data, tuple) and len(data) == 3 + elif op in (Op.REF, Op.DEREF): + # data is Expr instance + assert isinstance(data, Expr) else: raise NotImplementedError( f'unknown op or missing sanity check: {op}') @@ -327,6 +336,21 @@ def tostring(self, parent_precedence=Precedence.NONE, for arg in self.data] r = " // ".join(args) precedence = Precedence.PRODUCT + elif self.op is Op.TERNARY: + cond, expr1, expr2 = [a.tostring(Precedence.TUPLE, + language=language) + for a in self.data] + if language is Language.C: + return f'({cond} ? {expr1} : {expr2})' + if language is Language.Python: + return f'({expr1} if {cond} else {expr2})' + if language is Language.Fortran: + return f'merge({expr1}, {expr2}, {cond})' + raise NotImplementedError(f'tostring for {self.op} and {language}') + elif self.op is Op.REF: + return '&' + self.data.tostring(language=language) + elif self.op is Op.DEREF: + return '*' + self.data.tostring(language=language) else: raise NotImplementedError(f'tostring for op {self.op}') if parent_precedence.value > precedence.value: @@ -561,6 +585,12 @@ def substitute(self, symbols_map): func = func.substitute(symbols_map) args = tuple(a.substitute(symbols_map) for a in self.data[1:]) return normalize(Expr(self.op, (func,) + args)) + if self.op is Op.TERNARY: + operands = tuple(a.substitute(symbols_map) for a in self.data) + return normalize(Expr(self.op, operands)) + if self.op in (Op.REF, Op.DEREF): + return normalize(Expr(self.op, self.data.substitute(symbols_map))) + raise NotImplementedError(f'substitute method for {self.op}: {self!r}') def traverse(self, visit, *args, **kwargs): @@ -579,7 +609,7 @@ def traverse(self, visit, *args, **kwargs): if self.op in (Op.INTEGER, Op.REAL, Op.STRING, Op.SYMBOL): return self - elif self.op in (Op.COMPLEX, Op.ARRAY, Op.CONCAT): + elif self.op in (Op.COMPLEX, Op.ARRAY, Op.CONCAT, Op.TERNARY): return normalize(Expr(self.op, tuple( item.traverse(visit, *args, **kwargs) for item in self.data))) @@ -609,6 +639,9 @@ def traverse(self, visit, *args, **kwargs): indices = tuple(index.traverse(visit, *args, **kwargs) for index in self.data[1:]) return normalize(Expr(self.op, (obj,) + indices)) + elif self.op in (Op.REF, Op.DEREF): + return normalize(Expr(self.op, + self.data.traverse(visit, *args, **kwargs))) raise NotImplementedError(f'traverse method for {self.op}') def contains(self, other): @@ -813,6 +846,13 @@ def normalize(obj): if len(lst) == 1: return lst[0] return Expr(Op.CONCAT, tuple(lst)) + + if obj.op is Op.TERNARY: + cond, expr1, expr2 = map(normalize, obj.data) + if cond.op is Op.INTEGER: + return expr1 if cond.data[0] else expr2 + return Expr(Op.TERNARY, (cond, expr1, expr2)) + return obj @@ -905,6 +945,24 @@ def as_apply(func, *args, **kwargs): dict((k, as_expr(v)) for k, v in kwargs.items()))) +def as_ternary(cond, expr1, expr2): + """Return object as TERNARY expression (cond?expr1:expr2). + """ + return Expr(Op.TERNARY, (cond, expr1, expr2)) + + +def as_ref(expr): + """Return object as referencing expression. + """ + return Expr(Op.REF, expr) + + +def as_deref(expr): + """Return object as dereferencing expression. + """ + return Expr(Op.DEREF, expr) + + def as_terms(obj): """Return expression as TERMS expression. """ @@ -967,7 +1025,8 @@ def as_numer_denom(obj): """ if isinstance(obj, Expr): obj = normalize(obj) - if obj.op in (Op.INTEGER, Op.REAL, Op.COMPLEX, Op.SYMBOL, Op.INDEXING): + if obj.op in (Op.INTEGER, Op.REAL, Op.COMPLEX, Op.SYMBOL, + Op.INDEXING, Op.TERNARY): return obj, as_number(1) elif obj.op is Op.APPLY: if obj.data[0] is ArithOp.DIV and not obj.data[2]: @@ -1017,8 +1076,8 @@ def _counter(): COUNTER = _counter() -def replace_quotes(s): - """Replace quoted substrings of input. +def eliminate_quotes(s): + """Replace quoted substrings of input string. Return a new string and a mapping of replacements. """ @@ -1030,15 +1089,31 @@ def repl(m): # remove trailing underscore kind = kind[:-1] p = {"'": "SINGLE", '"': "DOUBLE"}[value[0]] - k = f'@__f2py_QUOTES_{p}_{COUNTER.__next__()}@' - d[as_symbol(k)] = as_string(value, kind or 1) + k = f'{kind}@__f2py_QUOTES_{p}_{COUNTER.__next__()}@' + d[k] = value return k - return (re.sub(r'({kind}_|)({single_quoted}|{double_quoted})'.format( + new_s = re.sub(r'({kind}_|)({single_quoted}|{double_quoted})'.format( kind=r'\w[\w\d_]*', single_quoted=r"('([^'\\]|(\\.))*')", double_quoted=r'("([^"\\]|(\\.))*")'), - repl, s), d) + repl, s) + + assert '"' not in new_s + assert "'" not in new_s + + return new_s, d + + +def insert_quotes(s, d): + """Inverse of eliminate_quotes. + """ + for k, v in d.items(): + kind = k[:k.find('@')] + if kind: + kind += '_' + s = s.replace(k, kind + v) + return s def replace_parenthesis(s): @@ -1084,21 +1159,32 @@ def replace_parenthesis(s): return s[:i] + k + r, d -def fromstring(s): - """Create an expression from string. +def _get_parenthesis_kind(s): + assert s.startswith('@__f2py_PARENTHESIS_'), s + return s.split('_')[4] - This is a "lazy" parser, that is, only arithmetic operations are - resolved, non-arithmetic operations are treated as symbols. + +def unreplace_parenthesis(s, d): + """Inverse of replace_parenthesis. """ - # We replace string literal constants that may include parenthesis - # that may confuse the arithmetic expression parser. - s, quotes_map = replace_quotes(s) + for k, v in d.items(): + p = _get_parenthesis_kind(k) + left = dict(ROUND='(', SQUARE='[', CURLY='{', ROUNDDIV='(/')[p] + right = dict(ROUND=')', SQUARE=']', CURLY='}', ROUNDDIV='/)')[p] + s = s.replace(k, left + v + right) + return s - # Apply arithmetic expression parser. - r = _fromstring_worker(s) - # Restore string constants. - return r.substitute(quotes_map) +def fromstring(s, language=Language.C): + """Create an expression from a string. + + This is a "lazy" parser, that is, only arithmetic operations are + resolved, non-arithmetic operations are treated as symbols. + """ + r = _FromStringWorker(language=language).parse(s) + if isinstance(r, Expr): + return r + raise ValueError(f'failed to parse `{s}` to Expr instance: got `{r}`') class _Pair: @@ -1120,151 +1206,210 @@ def __repr__(self): return f'{type(self).__name__}({self.left}, {self.right})' -def _fromstring_worker(s, dummy=None): - # Internal method used by fromstring to convert a string of - # arithmetic expressions to an expression tree. +class _FromStringWorker: - # Replace subexpression enclosed in parenthesis with dummy symbols - # and recursively parse subexpression into symbol-expression map: - r, raw_symbols_map = replace_parenthesis(s) - r = r.strip() + def __init__(self, language=Language.C): + self.original = None + self.quotes_map = None + self.language = language - symbols_map = dict([(as_symbol(k), _fromstring_worker(v)) - for k, v in raw_symbols_map.items()]) + def finalize_string(self, s): + return insert_quotes(s, self.quotes_map) - if ',' in r: - # comma-separate tuple of expressions - return tuple(_fromstring_worker(r_).substitute(symbols_map) - for r_ in r.split(',')) + def parse(self, inp): + self.original = inp + unquoted, self.quotes_map = eliminate_quotes(inp) + return self.process(unquoted) - m = re.match(r'\A(\w[\w\d_]*)\s*[=](.*)\Z', r) - if m: - # keyword argument - keyname, value = m.groups() - return _Pair(keyname, - _fromstring_worker(value).substitute(symbols_map)) - - operands = re.split(r'((? 1: - # Expression is an arithmetic sum - result = _fromstring_worker(operands[0] or '0') - for op, operand in zip(operands[1::2], operands[2::2]): - op = op.strip() - if op == '+': - result += _fromstring_worker(operand) - else: - assert op == '-' - result -= _fromstring_worker(operand) - return result.substitute(symbols_map) - - operands = re.split(r'//', r) - if len(operands) > 1: - # Expression is string concatenate operation - return Expr(Op.CONCAT, - tuple(_fromstring_worker(operand).substitute(symbols_map) - for operand in operands)) - - operands = re.split(r'([*]|/)', r.replace('**', '@__f2py_DOUBLE_STAR@')) - operands = [operand.replace('@__f2py_DOUBLE_STAR@', '**') - for operand in operands] - if len(operands) > 1: - # Expression is an arithmetic product - result = _fromstring_worker(operands[0]) - for op, operand in zip(operands[1::2], operands[2::2]): - op = op.strip() - if op == '*': - result *= _fromstring_worker(operand) + def process(self, s, context='expr'): + """Parse string within the given context. + + The context may define the result in case of ambiguous + expressions. For instance, consider expressions `f(x, y)` and + `(x, y) + (a, b)` where `f` is a function and pair `(x, y)` + denotes complex number. Specifying context as "args" or + "expr", the subexpression `(x, y)` will be parse to an + argument list or to a complex number, respectively. + """ + if isinstance(s, (list, tuple)): + return type(s)(self.process(s_, context) for s_ in s) + + assert isinstance(s, str), (type(s), s) + + # replace subexpressions in parenthesis with f2py @-names + r, raw_symbols_map = replace_parenthesis(s) + r = r.strip() + + def restore(r): + # restores subexpressions marked with f2py @-names + if isinstance(r, (list, tuple)): + return type(r)(map(restore, r)) + return unreplace_parenthesis(r, raw_symbols_map) + + # comma-separated tuple + if ',' in r: + operands = restore(r.split(',')) + if context == 'args': + return tuple(self.process(operands)) + if context == 'expr': + if len(operands) == 2: + # complex number literal + return as_complex(*self.process(operands)) + raise NotImplementedError( + f'parsing comma-separated list (context={context}): {r}') + return tuple(self.process(restore(r.split(',')), context)) + + # ternary operation + m = re.match(r'\A([^?]+)[?]([^:]+)[:](.+)\Z', r) + if m: + assert context == 'expr', context + oper, expr1, expr2 = restore(m.groups()) + if 0: + # TODO: enable this when support for boolean + # expressions is fully implemented + oper = self.process(oper) else: - assert op == '/' - result /= _fromstring_worker(operand) - return result.substitute(symbols_map) - - operands = list(reversed(re.split(r'[*][*]', r))) - if len(operands) > 1: - # Expression is an arithmetic exponentiation - result = _fromstring_worker(operands[0]) - for operand in operands[1:]: - result = _fromstring_worker(operand) ** result - return result - - m = re.match(r'\A({digit_string})({kind}|)\Z'.format( - digit_string=r'\d+', - kind=r'_(\d+|\w[\w\d_]*)'), r) - if m: - # Expression is int-literal-constant - value, _, kind = m.groups() - if kind and kind.isdigit(): - kind = int(kind) - return as_integer(int(value), kind or 4) - - m = re.match(r'\A({significant}({exponent}|)|\d+{exponent})({kind}|)\Z' - .format( - significant=r'[.]\d+|\d+[.]\d*', - exponent=r'[edED][+-]?\d+', - kind=r'_(\d+|\w[\w\d_]*)'), r) - if m: - # Expression is real-literal-constant - value, _, _, kind = m.groups() - if kind and kind.isdigit(): - kind = int(kind) - value = value.lower() - if 'd' in value: - return as_real(float(value.replace('d', 'e')), kind or 8) - return as_real(float(value), kind or 4) - - m = re.match(r'\A(\w[\w\d_]*)_(@__f2py_QUOTES_(\w+)_\d+@)\Z', r) - if m: + oper = as_symbol(self.finalize_string(oper)) + expr1 = self.process(expr1) + expr2 = self.process(expr2) + return as_ternary(oper, expr1, expr2) + + # keyword argument + m = re.match(r'\A(\w[\w\d_]*)\s*[=](.*)\Z', r) + if m: + keyname, value = m.groups() + value = restore(value) + return _Pair(keyname, self.process(value)) + + # addition/subtraction operations + operands = re.split(r'((? 1: + result = self.process(restore(operands[0] or '0')) + for op, operand in zip(operands[1::2], operands[2::2]): + operand = self.process(restore(operand)) + op = op.strip() + if op == '+': + result += operand + else: + assert op == '-' + result -= operand + return result + + # string concatenate operation + if self.language is Language.Fortran and '//' in r: + operands = restore(r.split('//')) + return Expr(Op.CONCAT, + tuple(self.process(operands))) + + # multiplication/division operations + operands = re.split(r'(?<=[@\w\d_])\s*([*]|/)', + (r if self.language is Language.C + else r.replace('**', '@__f2py_DOUBLE_STAR@'))) + if len(operands) > 1: + operands = restore(operands) + if self.language is not Language.C: + operands = [operand.replace('@__f2py_DOUBLE_STAR@', '**') + for operand in operands] + # Expression is an arithmetic product + result = self.process(operands[0]) + for op, operand in zip(operands[1::2], operands[2::2]): + operand = self.process(operand) + op = op.strip() + if op == '*': + result *= operand + else: + assert op == '/' + result /= operand + return result + + # referencing/dereferencing + if r.startswith('*') or r.startswith('&'): + op = {'*': Op.DEREF, '&': Op.REF}[r[0]] + operand = self.process(restore(r[1:])) + return Expr(op, operand) + + # exponentiation operations + if self.language is not Language.C and '**' in r: + operands = list(reversed(restore(r.split('**')))) + result = self.process(operands[0]) + for operand in operands[1:]: + operand = self.process(operand) + result = operand ** result + return result + + # int-literal-constant + m = re.match(r'\A({digit_string})({kind}|)\Z'.format( + digit_string=r'\d+', + kind=r'_(\d+|\w[\w\d_]*)'), r) + if m: + value, _, kind = m.groups() + if kind and kind.isdigit(): + kind = int(kind) + return as_integer(int(value), kind or 4) + + # real-literal-constant + m = re.match(r'\A({significant}({exponent}|)|\d+{exponent})({kind}|)\Z' + .format( + significant=r'[.]\d+|\d+[.]\d*', + exponent=r'[edED][+-]?\d+', + kind=r'_(\d+|\w[\w\d_]*)'), r) + if m: + value, _, _, kind = m.groups() + if kind and kind.isdigit(): + kind = int(kind) + value = value.lower() + if 'd' in value: + return as_real(float(value.replace('d', 'e')), kind or 8) + return as_real(float(value), kind or 4) + # string-literal-constant with kind parameter specification - kind, name, _ = m.groups() - return as_string(name, kind) - - m = re.match(r'\A(\w[\w\d+]*)\s*(@__f2py_PARENTHESIS_(\w+)_\d+@)\Z', r) - if m: - # apply or indexing expression - name, args, paren = m.groups() - target = as_symbol(name).substitute(symbols_map) - if args in raw_symbols_map: - args = symbols_map[as_symbol(args)] + if r in self.quotes_map: + kind = r[:r.find('@')] + return as_string(self.quotes_map[r], kind or 1) + + # array constructor or literal complex constant or + # parenthesized expression + if r in raw_symbols_map: + paren = _get_parenthesis_kind(r) + items = self.process(restore(raw_symbols_map[r]), + 'expr' if paren == 'ROUND' else 'args') + if paren == 'ROUND': + if isinstance(items, Expr): + return items + if paren in ['ROUNDDIV', 'SQUARE']: + # Expression is a array constructor + if isinstance(items, Expr): + items = (items,) + return as_array(items) + + # function call/indexing + m = re.match(r'\A(.+)\s*(@__f2py_PARENTHESIS_(ROUND|SQUARE)_\d+@)\Z', + r) + if m: + target, args, paren = m.groups() + target = self.process(restore(target)) + args = self.process(restore(args)[1:-1], 'args') if not isinstance(args, tuple): args = args, if paren == 'ROUND': - # Expression is a function call or a Fortran indexing - # operation kwargs = dict((a.left, a.right) for a in args if isinstance(a, _Pair)) args = tuple(a for a in args if not isinstance(a, _Pair)) + # Warning: this could also be Fortran indexing operation.. return as_apply(target, *args, **kwargs) - if paren == 'SQUARE': - # Expression is a C indexing operation (e.g. used in - # .pyf files) + else: + # Expression is a C/Python indexing operation + # (e.g. used in .pyf files) + assert paren == 'SQUARE' return target[args] - # A non-Fortran/C expression? - assert 0, (s, r) - - m = re.match(r'\A(@__f2py_PARENTHESIS_(\w+)_\d+@)\Z', r) - if m: - # array constructor or literal complex constant - items, paren = m.groups() - if items in raw_symbols_map: - items = symbols_map[as_symbol(items)] - if (paren == 'ROUND' and isinstance(items, - tuple) and len(items) == 2): - # Expression is a literal complex constant - return as_complex(items[0], items[1]) - if paren in ['ROUNDDIV', 'SQUARE']: - # Expression is a array constructor - return as_array(items) - assert 0, (r, items, paren) - m = re.match(r'\A\w[\w\d_]*\Z', r) - if m: - # Fortran standard conforming name - return as_symbol(r) + # Fortran standard conforming identifier + m = re.match(r'\A\w[\w\d_]*\Z', r) + if m: + return as_symbol(r) - m = re.match(r'\A@\w[\w\d_]*@\Z', r) - if m: - # f2py special dummy name + # fall-back to symbol + r = self.finalize_string(restore(r)) + ewarn( + f'fromstring: treating {r!r} as symbol (original={self.original})') return as_symbol(r) - - ewarn(f'fromstring: treating {r!r} as symbol') - return as_symbol(r) diff --git a/numpy/f2py/tests/test_symbolic.py b/numpy/f2py/tests/test_symbolic.py index 27244c995636..61f0d4b3ad26 100644 --- a/numpy/f2py/tests/test_symbolic.py +++ b/numpy/f2py/tests/test_symbolic.py @@ -2,34 +2,35 @@ from numpy.f2py.symbolic import ( Expr, Op, ArithOp, Language, as_symbol, as_number, as_string, as_array, as_complex, - as_terms, as_factors, fromstring, replace_quotes, as_expr, as_apply, - as_numer_denom + as_terms, as_factors, eliminate_quotes, insert_quotes, + fromstring, as_expr, as_apply, + as_numer_denom, as_ternary, as_ref, as_deref, + normalize ) from . import util class TestSymbolic(util.F2PyTest): - def test_replace_quotes(self): - + def test_eliminate_quotes(self): def worker(s): - r, d = replace_quotes(s) - assert '"' not in r, r - assert "'" not in r, r - # undo replacements: - s1 = r - for k, v in d.items(): - s1 = s1.replace(k.data, v.data[0]) + r, d = eliminate_quotes(s) + s1 = insert_quotes(r, d) assert s1 == s - worker('"1234" // "ABCD"') - worker('"1234" // \'ABCD\'') - worker('"1\\"2\'AB\'34"') - worker("'1\\'2\"AB\"34'") + for kind in ['', 'mykind_']: + worker(kind + '"1234" // "ABCD"') + worker(kind + '"1234" // ' + kind + '"ABCD"') + worker(kind + '"1234" // \'ABCD\'') + worker(kind + '"1234" // ' + kind + '\'ABCD\'') + worker(kind + '"1\\"2\'AB\'34"') + worker('a = ' + kind + "'1\\'2\"AB\"34'") def test_sanity(self): x = as_symbol('x') y = as_symbol('y') + z = as_symbol('z') + assert x.op == Op.SYMBOL assert repr(x) == "Expr(Op.SYMBOL, 'x')" assert x == x @@ -92,9 +93,17 @@ def test_sanity(self): assert w != v assert hash(v) is not None + t = as_ternary(x, y, z) + u = as_ternary(x, z, y) + assert t.op == Op.TERNARY + assert t == t + assert t != u + assert hash(t) is not None + def test_tostring_fortran(self): x = as_symbol('x') y = as_symbol('y') + z = as_symbol('z') n = as_number(123) m = as_number(456) a = as_array((n, m)) @@ -132,10 +141,13 @@ def test_tostring_fortran(self): assert str(Expr(Op.APPLY, ('f', (x, y), {}))) == 'f(x, y)' assert str(Expr(Op.INDEXING, ('f', x))) == 'f[x]' + assert str(as_ternary(x, y, z)) == 'merge(y, z, x)' + def test_tostring_c(self): language = Language.C x = as_symbol('x') y = as_symbol('y') + z = as_symbol('z') n = as_number(123) assert Expr(Op.FACTORS, {x: 2}).tostring(language=language) == 'x * x' @@ -153,6 +165,8 @@ def test_tostring_c(self): assert (x + (x - y) / (x + y) + n).tostring( language=language) == '123 + x + (x - y) / (x + y)' + assert as_ternary(x, y, z).tostring(language=language) == '(x ? y : z)' + def test_operations(self): x = as_symbol('x') y = as_symbol('y') @@ -224,6 +238,9 @@ def test_substitute(self): assert x.substitute({x: y + z}) == y + z assert a.substitute({x: y + z}) == as_array((y + z, y)) + assert as_ternary(x, y, z).substitute( + {x: y + z}) == as_ternary(y + z, y, z) + def test_fromstring(self): x = as_symbol('x') @@ -242,16 +259,20 @@ def test_fromstring(self): assert fromstring('x * y') == x * y assert fromstring('x * 2') == x * 2 assert fromstring('x / y') == x / y - assert fromstring('x ** 2') == x ** 2 - assert fromstring('x ** 2 ** 3') == x ** 2 ** 3 + assert fromstring('x ** 2', + language=Language.Python) == x ** 2 + assert fromstring('x ** 2 ** 3', + language=Language.Python) == x ** 2 ** 3 assert fromstring('(x + y) * z') == (x + y) * z assert fromstring('f(x)') == f(x) assert fromstring('f(x,y)') == f(x, y) assert fromstring('f[x]') == f[x] + assert fromstring('f[x][y]') == f[x][y] assert fromstring('"ABC"') == s - assert fromstring('"ABC" // "123" ') == s // t + assert normalize(fromstring('"ABC" // "123" ', + language=Language.Fortran)) == s // t assert fromstring('f("ABC")') == f(s) assert fromstring('MYSTRKIND_"ABC"') == as_string('"ABC"', 'MYSTRKIND') @@ -288,6 +309,16 @@ def test_fromstring(self): age=as_number(50), shape=as_array((as_number(34), as_number(23))))) + assert fromstring('x?y:z') == as_ternary(x, y, z) + + assert fromstring('*x') == as_deref(x) + assert fromstring('**x') == as_deref(as_deref(x)) + assert fromstring('&x') == as_ref(x) + assert fromstring('(*x) * (*y)') == as_deref(x) * as_deref(y) + assert fromstring('(*x) * *y') == as_deref(x) * as_deref(y) + assert fromstring('*x * *y') == as_deref(x) * as_deref(y) + assert fromstring('*x**y') == as_deref(x) * as_deref(y) + def test_traverse(self): x = as_symbol('x') y = as_symbol('y') From 88f988ae6226ffb5e3c9ae41a2d0cf9fd982109a Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Fri, 3 Sep 2021 21:56:18 +0300 Subject: [PATCH 10/10] Fix a bug in dependencies. Implement support for relational operations. --- numpy/f2py/crackfortran.py | 7 +- numpy/f2py/symbolic.py | 141 +++++++++++++++++++++++++----- numpy/f2py/tests/test_symbolic.py | 38 +++++++- 3 files changed, 157 insertions(+), 29 deletions(-) diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py index 2adca9fba7e9..3066657206c1 100755 --- a/numpy/f2py/crackfortran.py +++ b/numpy/f2py/crackfortran.py @@ -2701,11 +2701,6 @@ def analyzevars(block): 'required' if is_required else 'optional') if v_attr: vars[v]['attrspec'] = v_attr - else: - # n is output or hidden argument, hence it - # will depend on all variables in d - n_deps.extend(coeffs_and_deps) - if coeffs_and_deps is not None: # extend v dependencies with ones specified in attrspec for v, (solver, deps) in coeffs_and_deps.items(): @@ -2716,6 +2711,8 @@ def analyzevars(block): v_deps.extend(aa[7:-1].split(',')) if v_deps: vars[v]['depend'] = list(set(v_deps)) + if n not in v_deps: + n_deps.append(v) elif isstring(vars[n]): if 'charselector' in vars[n]: if '*' in vars[n]['charselector']: diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py index 0a2c3ba0471f..b747a75f916b 100644 --- a/numpy/f2py/symbolic.py +++ b/numpy/f2py/symbolic.py @@ -15,7 +15,6 @@ # # TODO: support logical constants (Op.BOOLEAN) # TODO: support logical operators (.AND., ...) -# TODO: support relational operators (<, >, ..., .LT., ...) # TODO: support defined operators (.MYOP., ...) # __all__ = ['Expr'] @@ -50,12 +49,43 @@ class Op(Enum): APPLY = 200 INDEXING = 210 CONCAT = 220 + RELATIONAL = 300 TERMS = 1000 FACTORS = 2000 REF = 3000 DEREF = 3001 +class RelOp(Enum): + """ + Used in Op.RELATIONAL expression to specify the function part. + """ + EQ = 1 + NE = 2 + LT = 3 + LE = 4 + GT = 5 + GE = 6 + + @classmethod + def fromstring(cls, s, language=Language.C): + if language is Language.Fortran: + return {'.eq.': RelOp.EQ, '.ne.': RelOp.NE, + '.lt.': RelOp.LT, '.le.': RelOp.LE, + '.gt.': RelOp.GT, '.ge.': RelOp.GE}[s.lower()] + return {'==': RelOp.EQ, '!=': RelOp.NE, '<': RelOp.LT, + '<=': RelOp.LE, '>': RelOp.GT, '>=': RelOp.GE}[s] + + def tostring(self, language=Language.C): + if language is Language.Fortran: + return {RelOp.EQ: '.eq.', RelOp.NE: '.ne.', + RelOp.LT: '.lt.', RelOp.LE: '.le.', + RelOp.GT: '.gt.', RelOp.GE: '.ge.'}[self] + return {RelOp.EQ: '==', RelOp.NE: '!=', + RelOp.LT: '<', RelOp.LE: '<=', + RelOp.GT: '>', RelOp.GE: '>='}[self] + + class ArithOp(Enum): """ Used in Op.APPLY expression to specify the function part. @@ -77,12 +107,19 @@ class Precedence(Enum): """ Used as Expr.tostring precedence argument. """ - NONE = 0 - TUPLE = 1 - SUM = 2 + ATOM = 0 + POWER = 1 + UNARY = 2 PRODUCT = 3 - POWER = 4 - ATOM = 5 + SUM = 4 + LT = 6 + EQ = 7 + LAND = 11 + LOR = 12 + TERNARY = 13 + ASSIGN = 14 + TUPLE = 15 + NONE = 100 integer_types = (int,) @@ -178,6 +215,9 @@ def __init__(self, op, data): elif op in (Op.REF, Op.DEREF): # data is Expr instance assert isinstance(data, Expr) + elif op is Op.RELATIONAL: + # data is (, , ) + assert isinstance(data, tuple) and len(data) == 3 else: raise NotImplementedError( f'unknown op or missing sanity check: {op}') @@ -341,19 +381,32 @@ def tostring(self, parent_precedence=Precedence.NONE, language=language) for a in self.data] if language is Language.C: - return f'({cond} ? {expr1} : {expr2})' - if language is Language.Python: - return f'({expr1} if {cond} else {expr2})' - if language is Language.Fortran: - return f'merge({expr1}, {expr2}, {cond})' - raise NotImplementedError(f'tostring for {self.op} and {language}') + r = f'({cond} ? {expr1} : {expr2})' + elif language is Language.Python: + r = f'({expr1} if {cond} else {expr2})' + elif language is Language.Fortran: + r = f'merge({expr1}, {expr2}, {cond})' + else: + raise NotImplementedError( + f'tostring for {self.op} and {language}') + precedence = Precedence.ATOM elif self.op is Op.REF: - return '&' + self.data.tostring(language=language) + r = '&' + self.data.tostring(Precedence.UNARY, language=language) + precedence = Precedence.UNARY elif self.op is Op.DEREF: - return '*' + self.data.tostring(language=language) + r = '*' + self.data.tostring(Precedence.UNARY, language=language) + precedence = Precedence.UNARY + elif self.op is Op.RELATIONAL: + rop, left, right = self.data + precedence = (Precedence.EQ if rop in (RelOp.EQ, RelOp.NE) + else Precedence.LT) + left = left.tostring(precedence, language=language) + right = right.tostring(precedence, language=language) + rop = rop.tostring(language=language) + r = f'{left} {rop} {right}' else: raise NotImplementedError(f'tostring for op {self.op}') - if parent_precedence.value > precedence.value: + if parent_precedence.value < precedence.value: # If parent precedence is higher than operand precedence, # operand will be enclosed in parenthesis. return '(' + r + ')' @@ -590,7 +643,11 @@ def substitute(self, symbols_map): return normalize(Expr(self.op, operands)) if self.op in (Op.REF, Op.DEREF): return normalize(Expr(self.op, self.data.substitute(symbols_map))) - + if self.op is Op.RELATIONAL: + rop, left, right = self.data + left = left.substitute(symbols_map) + right = right.substitute(symbols_map) + return normalize(Expr(self.op, (rop, left, right))) raise NotImplementedError(f'substitute method for {self.op}: {self!r}') def traverse(self, visit, *args, **kwargs): @@ -642,6 +699,11 @@ def traverse(self, visit, *args, **kwargs): elif self.op in (Op.REF, Op.DEREF): return normalize(Expr(self.op, self.data.traverse(visit, *args, **kwargs))) + elif self.op is Op.RELATIONAL: + rop, left, right = self.data + left = left.traverse(visit, *args, **kwargs) + right = right.traverse(visit, *args, **kwargs) + return normalize(Expr(self.op, (rop, left, right))) raise NotImplementedError(f'traverse method for {self.op}') def contains(self, other): @@ -963,6 +1025,30 @@ def as_deref(expr): return Expr(Op.DEREF, expr) +def as_eq(left, right): + return Expr(Op.RELATIONAL, (RelOp.EQ, left, right)) + + +def as_ne(left, right): + return Expr(Op.RELATIONAL, (RelOp.NE, left, right)) + + +def as_lt(left, right): + return Expr(Op.RELATIONAL, (RelOp.LT, left, right)) + + +def as_le(left, right): + return Expr(Op.RELATIONAL, (RelOp.LE, left, right)) + + +def as_gt(left, right): + return Expr(Op.RELATIONAL, (RelOp.GT, left, right)) + + +def as_ge(left, right): + return Expr(Op.RELATIONAL, (RelOp.GE, left, right)) + + def as_terms(obj): """Return expression as TERMS expression. """ @@ -1257,23 +1343,32 @@ def restore(r): return as_complex(*self.process(operands)) raise NotImplementedError( f'parsing comma-separated list (context={context}): {r}') - return tuple(self.process(restore(r.split(',')), context)) # ternary operation m = re.match(r'\A([^?]+)[?]([^:]+)[:](.+)\Z', r) if m: assert context == 'expr', context oper, expr1, expr2 = restore(m.groups()) - if 0: - # TODO: enable this when support for boolean - # expressions is fully implemented - oper = self.process(oper) - else: - oper = as_symbol(self.finalize_string(oper)) + oper = self.process(oper) expr1 = self.process(expr1) expr2 = self.process(expr2) return as_ternary(oper, expr1, expr2) + # relational expression + if self.language is Language.Fortran: + m = re.match( + r'\A(.+)\s*[.](eq|ne|lt|le|gt|ge)[.]\s*(.+)\Z', r, re.I) + else: + m = re.match( + r'\A(.+)\s*([=][=]|[!][=]|[<][=]|[<]|[>][=]|[>])\s*(.+)\Z', r) + if m: + left, rop, right = m.groups() + if self.language is Language.Fortran: + rop = '.' + rop + '.' + left, right = self.process(restore((left, right))) + rop = RelOp.fromstring(rop, language=self.language) + return Expr(Op.RELATIONAL, (rop, left, right)) + # keyword argument m = re.match(r'\A(\w[\w\d_]*)\s*[=](.*)\Z', r) if m: diff --git a/numpy/f2py/tests/test_symbolic.py b/numpy/f2py/tests/test_symbolic.py index 61f0d4b3ad26..52cabac530d5 100644 --- a/numpy/f2py/tests/test_symbolic.py +++ b/numpy/f2py/tests/test_symbolic.py @@ -5,7 +5,7 @@ as_terms, as_factors, eliminate_quotes, insert_quotes, fromstring, as_expr, as_apply, as_numer_denom, as_ternary, as_ref, as_deref, - normalize + normalize, as_eq, as_ne, as_lt, as_gt, as_le, as_ge ) from . import util @@ -100,6 +100,13 @@ def test_sanity(self): assert t != u assert hash(t) is not None + e = as_eq(x, y) + f = as_lt(x, y) + assert e.op == Op.RELATIONAL + assert e == e + assert e != f + assert hash(e) is not None + def test_tostring_fortran(self): x = as_symbol('x') y = as_symbol('y') @@ -142,6 +149,12 @@ def test_tostring_fortran(self): assert str(Expr(Op.INDEXING, ('f', x))) == 'f[x]' assert str(as_ternary(x, y, z)) == 'merge(y, z, x)' + assert str(as_eq(x, y)) == 'x .eq. y' + assert str(as_ne(x, y)) == 'x .ne. y' + assert str(as_lt(x, y)) == 'x .lt. y' + assert str(as_le(x, y)) == 'x .le. y' + assert str(as_gt(x, y)) == 'x .gt. y' + assert str(as_ge(x, y)) == 'x .ge. y' def test_tostring_c(self): language = Language.C @@ -166,6 +179,12 @@ def test_tostring_c(self): language=language) == '123 + x + (x - y) / (x + y)' assert as_ternary(x, y, z).tostring(language=language) == '(x ? y : z)' + assert as_eq(x, y).tostring(language=language) == 'x == y' + assert as_ne(x, y).tostring(language=language) == 'x != y' + assert as_lt(x, y).tostring(language=language) == 'x < y' + assert as_le(x, y).tostring(language=language) == 'x <= y' + assert as_gt(x, y).tostring(language=language) == 'x > y' + assert as_ge(x, y).tostring(language=language) == 'x >= y' def test_operations(self): x = as_symbol('x') @@ -240,6 +259,8 @@ def test_substitute(self): assert as_ternary(x, y, z).substitute( {x: y + z}) == as_ternary(y + z, y, z) + assert as_eq(x, y).substitute( + {x: y + z}) == as_eq(y + z, y) def test_fromstring(self): @@ -319,6 +340,20 @@ def test_fromstring(self): assert fromstring('*x * *y') == as_deref(x) * as_deref(y) assert fromstring('*x**y') == as_deref(x) * as_deref(y) + assert fromstring('x == y') == as_eq(x, y) + assert fromstring('x != y') == as_ne(x, y) + assert fromstring('x < y') == as_lt(x, y) + assert fromstring('x > y') == as_gt(x, y) + assert fromstring('x <= y') == as_le(x, y) + assert fromstring('x >= y') == as_ge(x, y) + + assert fromstring('x .eq. y', language=Language.Fortran) == as_eq(x, y) + assert fromstring('x .ne. y', language=Language.Fortran) == as_ne(x, y) + assert fromstring('x .lt. y', language=Language.Fortran) == as_lt(x, y) + assert fromstring('x .gt. y', language=Language.Fortran) == as_gt(x, y) + assert fromstring('x .le. y', language=Language.Fortran) == as_le(x, y) + assert fromstring('x .ge. y', language=Language.Fortran) == as_ge(x, y) + def test_traverse(self): x = as_symbol('x') y = as_symbol('y') @@ -340,6 +375,7 @@ def replace_visit(s, r=z): assert (x + y + z).traverse(replace_visit) == (2 * z + y) assert (x + f(y, x - z)).traverse( replace_visit) == (z + f(y, as_number(0))) + assert as_eq(x, y).traverse(replace_visit) == as_eq(z, y) # Use traverse to collect symbols, method 1 function_symbols = set()