Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unit-aware quantity annotations. #10662

Merged
merged 7 commits into from
Oct 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions astropy/units/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Support for ``typing`` py3.9+ features while min version is py3.8.
"""

try: # py 3.9+
from typing import Annotated
except (ImportError, ModuleNotFoundError): # optional dependency
try:
from typing_extensions import Annotated
except (ImportError, ModuleNotFoundError):

Annotated = NotImplemented

HAS_ANNOTATED = Annotated is not NotImplemented
119 changes: 89 additions & 30 deletions astropy/units/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,43 @@

__all__ = ['quantity_input']

import inspect
import typing as T
from numbers import Number
from collections.abc import Sequence
import inspect
from functools import wraps

import numpy as np

from astropy.utils.decorators import wraps
from .core import (Unit, UnitBase, UnitsError, add_enabled_equivalencies,
dimensionless_unscaled)
from ._typing import Annotated
from .core import (Unit, UnitBase, UnitsError,
add_enabled_equivalencies, dimensionless_unscaled)
from .function.core import FunctionUnitBase
from .physical import _unit_physical_mapping
from .physical import PhysicalType, get_physical_type
from .quantity import Quantity
from .structured import StructuredUnit


NoneType = type(None)


def _get_allowed_units(targets):
"""
From a list of target units (either as strings or unit objects) and physical
types, return a list of Unit objects.
"""

allowed_units = []
for target in targets:

try: # unit passed in as a string
target_unit = Unit(target)

except ValueError:

try: # See if the function writer specified a physical type
physical_type_id = _unit_physical_mapping[target]

except KeyError: # Function argument target is invalid
raise ValueError(f"Invalid unit or physical type '{target}'.")

# get unit directly from physical type id
target_unit = Unit._from_physical_type_id(physical_type_id)
try:
unit = Unit(target)
except (TypeError, ValueError):
try:
unit = get_physical_type(target)._unit
except (TypeError, ValueError, KeyError): # KeyError for Enum
raise ValueError(f"Invalid unit or physical type {target!r}.") from None

allowed_units.append(target_unit)
allowed_units.append(unit)

return allowed_units

Expand Down Expand Up @@ -77,8 +77,7 @@ def _validate_arg_value(param_name, func_name, arg, targets, equivalencies,

except AttributeError: # Either there is no .unit or no .is_equivalent
if hasattr(arg, "unit"):
error_msg = ("a 'unit' attribute without an 'is_equivalent' "
"method")
error_msg = ("a 'unit' attribute without an 'is_equivalent' method")
else:
error_msg = "no 'unit' attribute"

Expand All @@ -96,6 +95,45 @@ def _validate_arg_value(param_name, func_name, arg, targets, equivalencies,
raise UnitsError(f"{error_msg} '{str(targets[0])}'.")


def _parse_annotation(target):

if target in (None, NoneType, inspect._empty):
return target

# check if unit-like
try:
unit = Unit(target)
except (TypeError, ValueError):
try:
ptype = get_physical_type(target)
except (TypeError, ValueError, KeyError): # KeyError for Enum
if isinstance(target, str):
raise ValueError(f"invalid unit or physical type {target!r}.") from None
else:
return ptype
else:
return unit

# could be a type hint
origin = T.get_origin(target)
if origin is T.Union:
return [_parse_annotation(t) for t in T.get_args(target)]
elif origin is not Annotated: # can't be Quantity[]
return False

# parse type hint
cls, *annotations = T.get_args(target)
if not issubclass(cls, Quantity) or not annotations:
return False

# get unit from type hint
unit, *rest = annotations
if not isinstance(unit, (UnitBase, PhysicalType)):
return False

return unit


class QuantityInput:

@classmethod
Expand Down Expand Up @@ -144,6 +182,14 @@ def myfunction(myangle):
def myfunction(myangle: u.arcsec):
return myangle**2

Or using a unit-aware Quantity annotation.

.. code-block:: python

@u.quantity_input
def myfunction(myangle: u.Quantity[u.arcsec]):
return myangle**2

Also you can specify a return value annotation, which will
cause the function to always return a `~astropy.units.Quantity` in that
unit.
Expand Down Expand Up @@ -209,6 +255,9 @@ def wrapper(*func_args, **func_kwargs):
targets = param.annotation
is_annotation = True

# parses to unit if it's an annotation (or list thereof)
targets = _parse_annotation(targets)

# If the targets is empty, then no target units or physical
# types were specified so we can continue to the next arg
if targets is inspect.Parameter.empty:
Expand All @@ -229,7 +278,7 @@ def wrapper(*func_args, **func_kwargs):

# Check for None in the supplied list of allowed units and, if
# present and the passed value is also None, ignore.
elif None in targets:
elif None in targets or NoneType in targets:
if arg is None:
continue
else:
Expand All @@ -243,7 +292,7 @@ def wrapper(*func_args, **func_kwargs):
# non unit related annotations to pass through
if is_annotation:
valid_targets = [t for t in valid_targets
if isinstance(t, (str, UnitBase))]
if isinstance(t, (str, UnitBase, PhysicalType))]
mhvk marked this conversation as resolved.
Show resolved Hide resolved

# Now we loop over the allowed units/physical types and validate
# the value of the argument:
Expand All @@ -255,12 +304,22 @@ def wrapper(*func_args, **func_kwargs):
with add_enabled_equivalencies(self.equivalencies):
return_ = wrapped_function(*func_args, **func_kwargs)

valid_empty = (inspect.Signature.empty, None)
if (wrapped_signature.return_annotation not in valid_empty) and isinstance(
wrapped_signature.return_annotation, (str, UnitBase, FunctionUnitBase)):
return return_.to(wrapped_signature.return_annotation)
else:
return return_
# Return
ra = wrapped_signature.return_annotation
valid_empty = (inspect.Signature.empty, None, NoneType, T.NoReturn)
if ra not in valid_empty:
target = (ra if T.get_origin(ra) not in (Annotated, T.Union)
else _parse_annotation(ra))
if isinstance(target, str) or not isinstance(target, Sequence):
target = [target]
valid_targets = [t for t in target
if isinstance(t, (str, UnitBase, PhysicalType))]
nstarman marked this conversation as resolved.
Show resolved Hide resolved
_validate_arg_value("return", wrapped_function.__name__,
return_, valid_targets, self.equivalencies,
self.strict_dimensionless)
if len(valid_targets) > 0:
return_ <<= valid_targets[0]
return return_

return wrapper

Expand Down
92 changes: 92 additions & 0 deletions astropy/units/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
from .structured import StructuredUnit
from .utils import is_effectively_unity
from .format.latex import Latex
from astropy.utils.compat import NUMPY_LT_1_22
from astropy.utils.compat.misc import override__dir__
from astropy.utils.exceptions import AstropyDeprecationWarning, AstropyWarning
from astropy.utils.introspection import minversion
from astropy.utils.misc import isiterable
from astropy.utils.data_info import ParentDtypeInfo
from astropy import config as _config
Expand Down Expand Up @@ -314,6 +316,96 @@ class Quantity(np.ndarray):

__array_priority__ = 10000

def __class_getitem__(cls, unit_shape_dtype):
"""Quantity Type Hints.

Unit-aware type hints are ``Annotated`` objects that encode the class,
the unit, and possibly shape and dtype information, depending on the
python and :mod:`numpy` versions.

Schematically, ``Annotated[cls[shape, dtype], unit]``

As a classmethod, the type is the class, ie ``Quantity``
produces an ``Annotated[Quantity, ...]`` while a subclass
like :class:`~astropy.coordinates.Angle` returns
``Annotated[Angle, ...]``.

Parameters
----------
unit_shape_dtype : :class:`~astropy.units.UnitBase`, str, `~astropy.units.PhysicalType`, or tuple
Unit specification, can be the physical type (ie str or class).
If tuple, then the first element is the unit specification
and all other elements are for `numpy.ndarray` type annotations.
Whether they are included depends on the python and :mod:`numpy`
versions.

Returns
-------
`typing.Annotated`, `typing_extensions.Annotated`, `astropy.units.Unit`, or `astropy.units.PhysicalType`
Return type in this preference order:
* if python v3.9+ : `typing.Annotated`
* if :mod:`typing_extensions` is installed : `typing_extensions.Annotated`
* `astropy.units.Unit` or `astropy.units.PhysicalType`

Raises
------
TypeError
If the unit/physical_type annotation is not Unit-like or
PhysicalType-like.

Examples
--------
Create a unit-aware Quantity type annotation

>>> Quantity[Unit("s")]
Annotated[Quantity, Unit("s")]

See Also
--------
`~astropy.units.quantity_input`
Use annotations for unit checks on function arguments and results.

Notes
-----
With Python 3.9+ or :mod:`typing_extensions`, |Quantity| types are also
static-type compatible.
"""
# LOCAL
from ._typing import HAS_ANNOTATED, Annotated

# process whether [unit] or [unit, shape, ptype]
if isinstance(unit_shape_dtype, tuple): # unit, shape, dtype
target = unit_shape_dtype[0]
shape_dtype = unit_shape_dtype[1:]
else: # just unit
target = unit_shape_dtype
shape_dtype = ()

# Allowed unit/physical types. Errors if neither.
try:
unit = Unit(target)
except (TypeError, ValueError):
from astropy.units.physical import get_physical_type

try:
unit = get_physical_type(target)
except (TypeError, ValueError, KeyError): # KeyError for Enum
raise TypeError("unit annotation is not a Unit or PhysicalType") from None

# Allow to sort of work for python 3.8- / no typing_extensions
# instead of bailing out, return the unit for `quantity_input`
if not HAS_ANNOTATED:
warnings.warn("Quantity annotations are valid static type annotations only"
" if Python is v3.9+ or `typing_extensions` is installed.")
return unit

# Quantity does not (yet) properly extend the NumPy generics types,
# introduced in numpy v1.22+, instead just including the unit info as
# metadata using Annotated.
if not NUMPY_LT_1_22:
cls = super().__class_getitem__((cls, *shape_dtype))
return Annotated.__class_getitem__((cls, unit))

def __new__(cls, value, unit=None, dtype=None, copy=True, order=None,
subok=False, ndmin=0):

Expand Down