Skip to content

Commit

Permalink
add ufunc to registry list
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman (@nstarman) <nstarkman@protonmail.com>
  • Loading branch information
nstarman committed Jul 21, 2021
1 parent dc3009a commit 39ad355
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 21 deletions.
6 changes: 5 additions & 1 deletion astropy/units/quantity_helper/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ def helper_clip(f, unit1, unit2, unit3):


# HELPER NARGS
REGISTERD_NARG_UFUNCS = set()


def _is_ulike(unit, allow_structured=False): # TODO! make a public-scoped function for this
"""Check if is unit-like."""
from astropy.units import Unit
Expand Down Expand Up @@ -356,7 +359,7 @@ def register_ufunc(ufunc, inunits, ounits, *, assume_correct_units=False):
Parameters
----------
ufunc : `~numpy.ufunc`
inunits, ounits : unit-like or sequence thereof
inunits, ounits : unit-like or sequence thereof or None
Sequence of the correct input and output units, respectively.
.. warning::
Expand Down Expand Up @@ -482,6 +485,7 @@ def helper_nargs(f, *units):
return converters, ounits

UFUNC_HELPERS[ufunc] = helper_nargs
REGISTERD_NARG_UFUNCS.add(ufunc)


# list of ufuncs:
Expand Down
24 changes: 14 additions & 10 deletions astropy/units/tests/test_quantity_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from astropy import units as u
from astropy.tests.helper import assert_quantity_allclose
from astropy.units import quantity_helper as qh
from astropy.units.utils import quantity_frompyfunc
from astropy.units.utils import frompyfunc
from astropy.units.quantity_helper.converters import UfuncHelpers
from astropy.units.quantity_helper.helpers import (
helper_sqrt, register_ufunc, _is_ulike)
Expand Down Expand Up @@ -67,12 +67,15 @@ def test_coverage(self):
set(qh.UFUNC_HELPERS.keys()))
# Check that every numpy ufunc is covered.
assert all_np_ufuncs - all_q_ufuncs == set()
# Check that all ufuncs we cover come from numpy or erfa.

# Check that all ufuncs we cover come from numpy or erfa or are
# registered with ``register_ufunc``
# (Since coverage for erfa is incomplete, we do not check
# this the other way).
all_erfa_ufuncs = set([ufunc for ufunc in erfa_ufunc.__dict__.values()
if isinstance(ufunc, np.ufunc)])
assert (all_q_ufuncs - all_np_ufuncs - all_erfa_ufuncs == set())
assert ((all_q_ufuncs - all_np_ufuncs - all_erfa_ufuncs
- qh.REGISTERD_NARG_UFUNCS) == set())

def test_scipy_registered(self):
# Should be registered as existing even if scipy is not available.
Expand Down Expand Up @@ -1499,6 +1502,7 @@ def badfunc(x):


class TestQuantityFromPyFunc:
"""Test `astropy.units.utils.frompyfunc`."""

def setup_class(self):
# registry of ufuncs
Expand All @@ -1509,15 +1513,15 @@ def setup_class(self):
for func in ufunc_list[:, 0]:
nin, nout = map(int, func.__name__.split("_")[1:])

self.ufunc_registry[func] = quantity_frompyfunc(
self.ufunc_registry[func] = frompyfunc(
func, nin, nout, inunits=[u.km] * nin, ounits=[u.km] * nout)
self.ufunc_assume_registry[func] = quantity_frompyfunc(
self.ufunc_assume_registry[func] = frompyfunc(
func, nin, nout, inunits=[u.km] * nin, ounits=[u.km] * nout,
assume_correct_units=True)
# introspect for units
self.ufunc_introspect_registry[func] = quantity_frompyfunc(
self.ufunc_introspect_registry[func] = frompyfunc(
func, nin, nout)
self.ufunc_introspect_assume_registry[func] = quantity_frompyfunc(
self.ufunc_introspect_assume_registry[func] = frompyfunc(
func, nin, nout, assume_correct_units=True)

# -------------------
Expand Down Expand Up @@ -1603,14 +1607,14 @@ def func(x: "km", y) -> ("km", "km"):
pass

with pytest.raises(ValueError, match="not equal `nout`"):
quantity_frompyfunc(func, 2, 1)
frompyfunc(func, 2, 1)

def test_returns_quantity_object_array(self):
"""Test when func returns a Quantity."""
def badfunc(x: u.Celsius):
return x << u.km

badufunc = quantity_frompyfunc(badfunc, 1, 1)
badufunc = frompyfunc(badfunc, 1, 1)
assert badufunc([0, 10, 20] * u.Celsius).dtype == object
assert badufunc([0, 10, 20] * u.Celsius)[0].unit == u.km

Expand All @@ -1619,5 +1623,5 @@ def test_dropin_replacement_for_frompyfunc(self):
def func(x):
return x

ufunc = quantity_frompyfunc(func, 1, 1)
ufunc = frompyfunc(func, 1, 1)
assert all(ufunc([0, 10, 20]) == np.array([0., 10., 20.], dtype=object))
20 changes: 10 additions & 10 deletions astropy/units/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ def quantity_asanyarray(a, dtype=None):
# ------------------------------------------------------------------------------


def quantity_frompyfunc(func, nin, nout, inunits=None, ounits=None,
*, identity=None, assume_correct_units=False):
"""Quantity-aware `~numpy.frompyfunc`.
def frompyfunc(func, nin, nout, inunits=None, ounits=None,
*, identity=None, assume_correct_units=False):
"""Quantity-aware `numpy.frompyfunc`. Works as a drop-in replacement.
`~numpy.ufunc`s operate on only recognized `~numpy.dtype`s (e.g. float32),
so units must be removed beforehand and replaced afterwards. Therefore units
Expand All @@ -313,7 +313,7 @@ def quantity_frompyfunc(func, nin, nout, inunits=None, ounits=None,
func : callable
nin, nout : int
Number of ufunc's inputs and outputs.
inunits, ounits : unit-like or sequence thereof (optional)
inunits, ounits : unit-like or sequence thereof or None (optional)
Sequence of the input and output units, respectively.
.. warning::
Expand Down Expand Up @@ -366,7 +366,7 @@ def quantity_frompyfunc(func, nin, nout, inunits=None, ounits=None,
units) and then the output units will be assigned.
``c2f`` will work on Quantities, but pretending it didn't...
>>> ufunc = quantity_frompyfunc(c2f, nin=1, nout=1,
>>> ufunc = frompyfunc(c2f, nin=1, nout=1,
... inunits=u.Celsius, ounits=Fahrenheit)
>>> ufunc
<ufunc 'c2f (vectorized)'>
Expand All @@ -393,32 +393,32 @@ def quantity_frompyfunc(func, nin, nout, inunits=None, ounits=None,
of Quantity will be returned instead of a Quantity array.
>>> def badc2f(x): return (9./5. * x + 32) << Fahrenheit
>>> badufunc = quantity_frompyfunc(badc2f, 1, 1, u.Celsius, Fahrenheit)
>>> badufunc = frompyfunc(badc2f, 1, 1, u.Celsius, Fahrenheit)
>>> badufunc([0, 10, 20] * u.Celsius)
<Quantity [<Quantity 32. deg_F>, <Quantity 50. deg_F>,
<Quantity 68. deg_F>] deg_F>
**Extra features**:
As a convenience, ``quantity_frompyfunc`` can also introspect function
As a convenience, ``frompyfunc`` can also introspect function
annotations and use these to determine the input and output units,
obviating the need for arguments ``inunits`` and ``ounits``.
>>> def c2f(x: u.Celsius) -> Fahrenheit: return 9./5. * x + 32
>>> ufunc = quantity_frompyfunc(c2f, 1, 1)
>>> ufunc = frompyfunc(c2f, 1, 1)
>>> ufunc(-40 * u.Celsius)
<Quantity -40. deg_F>
When a ufunc has at least 2 inputs, if one of the arguments does not have
units it is assumed to be `~astropy.units.dimensionless_unscaled`. However,
``quantity_frompyfunc`` takes the keyword argument "assume_correct_units",
``frompyfunc`` takes the keyword argument "assume_correct_units",
in which case the ufunc will instead interpret a unitless argument as
having units 'inunits' -- i.e. the correct units.
>>> def exf(x: u.km, y: u.s) -> u.km**2/u.s: return x ** 2 / y
>>> exufunc = quantity_frompyfunc(exf, 2, 1, assume_correct_units=True)
>>> exufunc = frompyfunc(exf, 2, 1, assume_correct_units=True)
>>> exufunc(3 * u.km, 2)
<Quantity 4.5 km2 / s>
Expand Down

0 comments on commit 39ad355

Please sign in to comment.