Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

quantity-aware frompyfunc #11893

Closed
wants to merge 13 commits into from
1 change: 1 addition & 0 deletions astropy/units/quantity_helper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# By importing helpers, all the unit conversion functions needed for
# numpy ufuncs and functions are defined.
from . import helpers, function_helpers
from .helpers import REGISTERED_NARG_UFUNCS
# For scipy.special and erfa, importing the helper modules ensures
# the definitions are added as modules to UFUNC_HELPERS, to be loaded
# on demand.
Expand Down
174 changes: 174 additions & 0 deletions astropy/units/quantity_helper/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
units for a given ufunc, given input units.
"""

from collections.abc import Sequence
from fractions import Fraction

import numpy as np
Expand All @@ -17,6 +18,7 @@
UnitsError, UnitConversionError, UnitTypeError,
dimensionless_unscaled, get_current_unit_registry,
unit_scale_converter)
from astropy.utils import isiterable


def _d(unit):
Expand Down Expand Up @@ -320,6 +322,178 @@ def helper_clip(f, unit1, unit2, unit3):
return converters, result_unit


# HELPER NARGS
REGISTERED_NARG_UFUNCS = set()


def _is_ulike(unit, allow_structured=False): # TODO! make a public-scoped function for this
"""Check if is unit-like."""
from astropy.units import Unit, StructuredUnit

try:
unit = Unit(unit)
except TypeError:
return False

if allow_structured:
ulike = True
elif isinstance(unit, StructuredUnit):
ulike = False
else:
ulike = True

return ulike


def _is_seq_ulike(seq, allow_structured=False):
"""Check if a sequence is unit-like."""
return (isinstance(seq, Sequence)
and all(_is_ulike(x, allow_structured) for x in seq))


def register_ufunc(ufunc, inunits, ounits, *, assume_correct_units=False):
"""
Register `~numpy.ufunc` in ``UFUNC_HELPERS``, along with the conversion
functions necessary to strip input units and assign output units. ufuncs
operate on only recognized `~numpy.dtype`s (e.g. int, float32), so units
must be removed beforehand and replaced afterwards. Therefore units MUST
BE KNOWN a priori, as they will not be propagated by the astropy machinery.

Note that ufuncs always return an object array.

Parameters
----------
ufunc : `~numpy.ufunc`
inunits, ounits : unit-like or sequence thereof or None
Sequence of the correct input and output units, respectively.

.. warning::
Inputs will be converted to these units before being passed to the
returned `~numpy.ufunc`. Outputs will be assigned these units.
Make sure these are the correct units.

assume_correct_units : bool, optional
When input arrays are given without units, but the ufunc has 'inunits',
whether the array is assumed to have dimensionless units (default) or
have 'inunits'.

Examples
--------
We first need to import relevant packages:

>>> import numpy as np
>>> import astropy.units as u
>>> from astropy.units.imperial import Fahrenheit

Now we can define a python function. For this example we will define the
conversion between Celsius and Fahrenheit.

>>> def c2f(x): return 9./5. * x + 32

With numpy this function can be turned into a `numpy.ufunc`. This is useful
if the python function works only on scalars, but we want to be able to
pass in arrays. ``c2f`` will work on arrays, but pretending it didn't...

>>> ufunc = np.frompyfunc(c2f, 1, 1)
>>> ufunc # doctest: +SKIP
<ufunc 'c2f (vectorized)'>

One of the limitations of a `numpy.ufunc` is that it cannot work with
`astropy.units.Quantity`. This is a partially solved problem as numpy
allows for `numpy.ufunc` evaluation to be overridden. We register this
``ufunc`` and provide the input and output units. The internal calculation
will be done on the unitless arrays (by converting to the input units)
and then the output units will be assigned.

>>> register_ufunc(ufunc, [u.Celsius], [Fahrenheit])

>>> ufunc(36 * u.Celsius)
<Quantity 96.8 deg_F>
>>> ufunc([0, 10, 20] * u.Celsius)
<Quantity [32.0, 50.0, 68.0] deg_F>


**There are two caveats to note**:

1. The `numpy.ufunc` overrides only work when at least one argument
is a `~astropy.units.Quantity`. In the above example ``c2f`` takes only
one argument, so if a scalar or `~numpy.ndarray` were passed instead of
a Quantity, the output will also be an ndarray.

>>> ufunc(36)
96.8
>>> ufunc(np.array([0, 10, 20])) # note return dtype is an object
array([32.0, 50.0, 68.0], dtype=object)

2. The function cannot return a Quantity with units. If so, an object array
of Quantity will be returned instead of a Quantity array.

>>> def badc2f(x): return (9./5. * x + 32) << Fahrenheit
>>> badufunc = np.frompyfunc(badc2f, 1, 1)
>>> register_ufunc(badufunc, [u.Celsius], [Fahrenheit])
>>> badufunc([0, 10, 20] * u.Celsius)
<Quantity [<Quantity 32. deg_F>, <Quantity 50. deg_F>,
<Quantity 68. deg_F>] deg_F>


**Extra feature**:

When a ufunc has at least 2 inputs, if one of the arguments does not have
units it is assumed to be `~astropy.units.dimensionless_unscaled`. However,
``register_ufunc`` takes the keyword argument "assume_correct_units",
in which case the ufunc will instead interpret a unitless argument as
having units 'inunits' -- i.e. the correct units.

>>> def exfunc(x: u.km, y: u.s) -> u.km**2/u.s: return x ** 2 / y
>>> exufunc = np.frompyfunc(exfunc, 2, 1)
>>> register_ufunc(exufunc, [u.km, u.s], ["km2/s"],
... assume_correct_units=True)
>>> exufunc(3 * u.km, 2) # arg 2 is unitless. 'inunits' are assumed.
<Quantity 4.5 km2 / s>

"""
from astropy.units import Unit, dimensionless_unscaled

# process sequence[unit-like] -> sequence[unit]
inunits = [inunits] if (inunits is None or _is_ulike(inunits)) else inunits
inunits = [(Unit(iu) if iu is not None else None) for iu in inunits]

# process sequence[unit-like] -> sequence[unit]
ounits = [ounits] if (ounits is None or _is_ulike(ounits)) else ounits
ounits = [(Unit(ou) if ou is not None else None) for ou in ounits]
ounits = ounits[0] if len(ounits) == 1 else ounits

# backup units for interpreting array (no units) input
fallbackinunits = (inunits if assume_correct_units
else [dimensionless_unscaled] * len(inunits))

def helper_nargs(f, *units):
"""Helper function to convert input units and assign output units.

Parameters
----------
f : callable
*units : `~astropy.units.UnitBase` or None
The units of the inputs. If None (a unitless input) the unit is
assumed to be the one specified in
`~astropy.units.quantity_helper.helpers.register_ufunc`

Returns
-------
converters : sequence[callable]
ounits : sequence[unit-like]

"""
# unit converters. No units = assumed to be in fallback units.
# if None in 'inunits', skip conversion
converters = [(get_converter(frm or fb, to) if to is not None else None)
for frm, to, fb in zip(units, inunits, fallbackinunits)]
return converters, ounits

UFUNC_HELPERS[ufunc] = helper_nargs
REGISTERED_NARG_UFUNCS.add(ufunc)


# list of ufuncs:
# https://numpy.org/doc/stable/reference/ufuncs.html#available-ufuncs

Expand Down