Skip to content

Commit

Permalink
start tests
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman (@nstarman) <nstarkman@protonmail.com>
  • Loading branch information
nstarman committed Jul 21, 2021
1 parent 17338ea commit f1b9e51
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 21 deletions.
16 changes: 9 additions & 7 deletions astropy/units/quantity_helper/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,14 +322,16 @@ def helper_clip(f, unit1, unit2, unit3):


# HELPER NARGS
def register_ufunc(ufunc, inunits, ounits, assume_correct_units=False):
def register_ufunc(ufunc, inunits, ounits, *, assume_correct_units=False):
"""
Register `~numpy.ufunc` in ``UFUNC_HELPERS``, along with the conversion
functions necessary to strip input units and assign output units. ufuncs
operate on only recognized `~numpy.dtype`s (e.g. int, float32), so units
must be removed beforehand and replaced afterwards. Therefore units MUST
BE KNOWN a priori, as they will not be propagated by the astropy machinery.
Note that ufuncs always return an object array.
Parameters
----------
ufunc : `~numpy.ufunc`
Expand All @@ -351,10 +353,11 @@ def register_ufunc(ufunc, inunits, ounits, assume_correct_units=False):
from astropy.units import Unit, dimensionless_unscaled

# process sequence[unit-like] -> sequence[unit]
if isiterable(inunits):
inunits = [(Unit(iu) if iu is not None else iu) for iu in inunits]
if isiterable(ounits):
ounits = [(Unit(ou) if ou is not None else ou) for ou in ounits]
inunits = [(Unit(iu) if iu is not None else None) for iu in inunits]

ounits = [(Unit(ou) if ou is not None else None) for ou in ounits]
ounits = ounits[0] if len(ounits) == 1 else ounits

# backup units for interpreting array (no units) input
fallbackinunits = (inunits if assume_correct_units
else [dimensionless_unscaled] * len(inunits))
Expand All @@ -376,8 +379,7 @@ def helper_nargs(f, *units):
ounits : sequence[unit-like]
"""
# unit converters
# no units assumed to be in fallback units
# unit converters. No units = assumed to be in fallback units.
# if None in 'inunits', skip conversion
converters = [(get_converter(frm or fb, to) if to is not None else None)
for frm, to, fb in zip(units, inunits, fallbackinunits)]
Expand Down
247 changes: 246 additions & 1 deletion astropy/units/tests/test_quantity_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# returns quantities with the right units, or raises exceptions.

import concurrent.futures
import inspect
import warnings
from collections import namedtuple

Expand All @@ -11,9 +12,11 @@
from erfa import ufunc as erfa_ufunc

from astropy import units as u
from astropy.tests.helper import assert_quantity_allclose
from astropy.units import quantity_helper as qh
from astropy.units.utils import quantity_frompyfunc, _is_ulike
from astropy.units.quantity_helper.converters import UfuncHelpers
from astropy.units.quantity_helper.helpers import helper_sqrt
from astropy.units.quantity_helper.helpers import helper_sqrt, register_ufunc
from astropy.utils.compat.optional_deps import HAS_SCIPY # noqa


Expand Down Expand Up @@ -1338,3 +1341,245 @@ def test_jv_invalid_units(self, function):
assert exc.value.args[0] == ("Can only apply '{}' function to "
"dimensionless quantities"
.format(function.__name__))


# -------------------------------------------------------------------

def make_ufunc_list():

def func_1_1(x: "km") -> "km2":
return x**2

def funcobj_1_1(x: "km") -> object:
return x**2

def func_2_1(x: "km", y) -> "km2":
return x**2

def func_1_2(x: "km") -> ("km2", "km3"):
return x**2, x**3

def func_2_2(x: "km", y: "km") -> ("km2", None):
return x**2, y**2

def funcobj_2_2(x: "km", y: "km") -> (object, object):
return x**2, y**2

ufunc_list = np.array(
# ( func, (input,), (output,) )
[(func_1_1, (2,), (4,)),
(funcobj_1_1, (2,), (4,)),
(func_2_1, (2, 1), (4,)),
(func_2_1, (2, None), (4,)),
(func_1_2, (2,), (4, 8)),
(func_2_2, (2, 3), (4, 9)),
(funcobj_2_2, (2, 3), (4, 9)),
# --- arrays --- #
(func_1_1, ([2, 3, 4],), [4, 9, 16]),
(funcobj_1_1, ([2, 3, 4],), [4, 9, 16]),
(func_2_1, ([2, 3, 4], 1), [4, 9, 16]),
(func_1_2, ([2, 3, 4],), ([4, 9, 16], [8, 27, 64])),
(func_2_2, ([2, 3, 4],[5, 6, 7]), ([4, 9, 16], [25, 36, 49])),
(funcobj_2_2, ([2, 3, 4],[5, 6, 7]), ([4, 9, 16], [25, 36, 49])),
],
dtype=object)

return ufunc_list


ufunc_list = make_ufunc_list()


class TestRegisterUfunc:

def setup_class(self):
# registry of normal ufuncs
self.ufunc_registry = {
func: np.frompyfunc(func, *map(int, func.__name__.split("_")[1:]))
for func in ufunc_list[:, 0]}

# registry of normal ufuncs, where the units will be assumed correct
self.ufunc_assume_registry = {
func: np.frompyfunc(func, *map(int, func.__name__.split("_")[1:]))
for func in ufunc_list[:, 0]}

# register as quantity-ufuncs
# NOTE! output units are "u.km" even though the functions are squaring
# or cubing the input. `register_ufunc` must be told the units since it
# can't introspect to find the correct units.
# This setup also tests `if isiterable(in/ounits)` in register_ufunc.
for func in self.ufunc_registry.keys():
ufunc = self.ufunc_registry[func]
register_ufunc(ufunc, inunits=[u.km] * ufunc.nin,
ounits=[u.km] * ufunc.nout)
# same as above, but assuming unitless has 'inunits'
ufunc = self.ufunc_assume_registry[func]
register_ufunc(ufunc, inunits=[u.km] * ufunc.nin,
ounits=[u.km] * ufunc.nout,
assume_correct_units=True)

@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_raw_ufunc(self, func, inp, res):
"""In this case, the output will also not have units."""
got = self.ufunc_registry[func](*inp)
# need to convert from an object array to float array, for comparison
# also, for multiple output, need to make array, before type casting
got = np.array(got).astype(float) if isinstance(got, (np.ndarray, tuple)) else got
assert_allclose(got, res)

@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_no_units(self, func, inp, res):
"""Test unitless input has unitless output. NO UNITS ATTACHED!"""
got = self.ufunc_registry[func](*inp) # no units
# need to convert from an object array to float array, for comparison
# also, for multiple output, need to make array, before type casting
got = np.array(got).astype(float) if isinstance(got, (np.ndarray, tuple)) else got
assert_allclose(got, res)

@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_has_units(self, func, inp, res):
"""Test unitful input has unitful output."""
inp = [(x * u.km if x is not None else None) for x in inp]

if None in inp:
with pytest.raises(u.UnitConversionError, match="not convertible"):
got = self.ufunc_registry[func](*inp)
return

got = self.ufunc_registry[func](*inp)
# need to convert from an object array to float array, for comparison
# also, for multiple output, need to make array, before type casting
got = (u.Quantity(got, dtype=float) if isinstance(got, (np.ndarray, tuple)) else got)
assert_quantity_allclose(got, res * u.km)

@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_has_units_assumed_correct(self, func, inp, res):
"""
Test unitful input has unitful output and units are assumed to be
correct if not given. This only applies."""
# partial unit assignment to trigger unit assumptions.
# single inputs can't trigger unit assumptions.
inp = [inp[0] * u.km, *inp[1:]] if len(inp) > 1 else inp * u.km

got = self.ufunc_assume_registry[func](*inp)
# need to convert from an object array to float array, for comparison
# also, for multiple output, need to make array, before type casting
got = (u.Quantity(got, dtype=float) if isinstance(got, (np.ndarray, tuple)) else got)
assert_quantity_allclose(got, res * u.km)

@pytest.mark.parametrize("registry",
["ufunc_registry", "ufunc_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_has_wrong_units(self, registry, func, inp, res):
"""Test wrong unitful input raises errors."""
inp = [(x * u.deg if x is not None else None) for x in inp]

with pytest.raises(u.UnitConversionError, match="'deg'"):
got = getattr(self, registry)[func](*inp)


class TestQuantityFromPyFunc:

def setup_class(self):
# registry of ufuncs
self.ufunc_registry = {}
self.ufunc_assume_registry = {}
self.ufunc_introspect_registry = {}
self.ufunc_introspect_assume_registry = {}
for func in ufunc_list[:, 0]:
nin, nout = map(int, func.__name__.split("_")[1:])

self.ufunc_registry[func] = quantity_frompyfunc(
func, nin, nout, inunits=[u.km] * nin, ounits=[u.km] * nout)
self.ufunc_assume_registry[func] = quantity_frompyfunc(
func, nin, nout, inunits=[u.km] * nin, ounits=[u.km] * nout,
assume_correct_units=True)
# introspect for units
self.ufunc_introspect_registry[func] = quantity_frompyfunc(
func, nin, nout)
self.ufunc_introspect_assume_registry[func] = quantity_frompyfunc(
func, nin, nout, assume_correct_units=True)

@pytest.mark.parametrize(
"registry",
["ufunc_registry", "ufunc_assume_registry",
"ufunc_introspect_registry", "ufunc_introspect_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_raw_ufunc(self, registry, func, inp, res):
"""In this case, the output will also not have units."""
got = getattr(self, registry)[func](*inp)
# need to convert from an object array to float array, for comparison
# also, for multiple output, need to make array, before type casting
got = np.array(got).astype(float) if isinstance(got, (np.ndarray, tuple)) else got
assert_allclose(got, res)

@pytest.mark.parametrize(
"registry",
["ufunc_registry", "ufunc_assume_registry",
"ufunc_introspect_registry", "ufunc_introspect_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_no_units(self, registry, func, inp, res):
"""Test unitless input has unitless output. NO UNITS ATTACHED!"""
got = getattr(self, registry)[func](*inp) # no units
# need to convert from an object array to float array, for comparison
# also, for multiple output, need to make array, before type casting
got = np.array(got).astype(float) if isinstance(got, (np.ndarray, tuple)) else got
assert_allclose(got, res)

@pytest.mark.parametrize(
"registry",
["ufunc_registry", "ufunc_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_has_units(self, registry, func, inp, res):
"""Test unitful input has unitful output."""
inp = [(x * u.km if x is not None else None) for x in inp]

if None in inp and "assume" not in registry:
with pytest.raises(u.UnitConversionError, match="not convertible"):
got = getattr(self, registry)[func](*inp)
return

got = getattr(self, registry)[func](*inp)
# need to convert from an object array to float array, for comparison
# also, for multiple output, need to make array, before type casting
got = (u.Quantity(got, dtype=float) if isinstance(got, (np.ndarray, tuple)) else got)
assert_quantity_allclose(got, res * u.km)

@pytest.mark.parametrize(
"registry",
["ufunc_introspect_registry", "ufunc_introspect_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_has_introspected_units(self, registry, func, inp, res):
"""Test unitful input has unitful output."""
# give units to inputs
inp = [(x * u.km if x is not None else None) for x in inp]
# evaluate ufunc
got = getattr(self, registry)[func](*inp)

# check units & need to convert from an object array to float array,
# for comparison also, for multiple output, need to make array,
# before type casting
ra = inspect.signature(func).return_annotation
if isinstance(got, tuple):
assert all((g.unit == u.Unit(a) for g, a in zip(got, ra)
if isinstance(g, u.Quantity) and _is_ulike(a)))

got = [(u.Quantity(g, dtype=float).value if
isinstance(g, np.ndarray) else g)
for g in got]
elif u.utils._is_ulike(ra): # 1 unit-like output annotation
assert got.unit == u.Unit(ra)

got = u.Quantity(got, dtype=float).value
assert_allclose(got, res)

@pytest.mark.parametrize(
"registry",
["ufunc_registry", "ufunc_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
def test_has_wrong_units(self, registry, func, inp, res):
"""Test wrong unitful input raises errors."""
inp = [(x * u.deg if x is not None else None) for x in inp]

with pytest.raises(u.UnitConversionError, match="'deg'"):
got = getattr(self, registry)[func](*inp)
37 changes: 24 additions & 13 deletions astropy/units/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,18 +299,25 @@ def quantity_asanyarray(a, dtype=None):

# ------------------------------------------------------------------------------

def _is_ulike(unit):
"""Check if is unit-like."""
from astropy.units import Unit

try:
Unit(unit) # TODO! worry about structured units
except TypeError:
return False
else:
return True


def _is_seq_ulike(seq):
"""Check if a sequence is unit-like."""
from astropy.units import UnitBase

is_unit = isinstance(seq, UnitBase)
is_unit_sequence = (isinstance(seq, Sequence) and
all(isinstance(x, UnitBase) for x in seq))
return True if (is_unit or is_unit_sequence) else False
return isinstance(seq, Sequence) and all(_is_ulike(x) for x in seq)


def quantity_frompyfunc(func, nin, nout, inunits=None, ounits=None,
*, identity=None):
*, identity=None, assume_correct_units=False):
"""Quantity-aware `~numpy.frompyfunc`.
`~numpy.ufunc`s operate on only recognized `~numpy.dtype`s (e.g. float32),
Expand Down Expand Up @@ -350,27 +357,31 @@ def quantity_frompyfunc(func, nin, nout, inunits=None, ounits=None,

# -------------------------
# determine units by introspection
# and ensure seq[unit-like] -> seq[unit]

sig = inspect.signature(func)

# input units
if inunits is None:
svals = tuple(sig.parameters.values())
# TODO! more robust. what if no annotations?
svals = tuple(sig.parameters.values()) # sequence[Parameter]
inunits = [Unit(p.annotation) if _is_seq_ulike(p.annotation) else None
for p in svals]

# output units
if ounits is None:
ra = sig.return_annotation
if _is_seq_ulike(ra):
ounits = ra
ra = [ra] if _is_ulike(ra) else ra # now a sequence, if unit-like
ounits = ra if _is_seq_ulike(ra) else [None]

if ounits != [None] and len(ounits) != nout:
raise ValueError(
"function annotation is a sequence of unit-like, but "
f"its length ({len(ra)}) does not equal `nout` ({nout})")

# -------------------------
# make and register ufunc

ufunc = np.frompyfunc(func, nin, nout, identity=identity)
register_ufunc(ufunc, inunits=inunits, ounits=ounits)
register_ufunc(ufunc, inunits=inunits, ounits=ounits,
assume_correct_units=assume_correct_units)

return ufunc

0 comments on commit f1b9e51

Please sign in to comment.