Skip to content

Commit

Permalink
Merge pull request #10662 from nstarman/unit-aware_quantity_annotations
Browse files Browse the repository at this point in the history
add unit-aware quantity annotations.
  • Loading branch information
nstarman committed Oct 29, 2021
2 parents 300565f + e6d147f commit 73541a2
Show file tree
Hide file tree
Showing 12 changed files with 559 additions and 66 deletions.
15 changes: 15 additions & 0 deletions astropy/units/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Support for ``typing`` py3.9+ features while min version is py3.8.
"""

try: # py 3.9+
from typing import Annotated
except (ImportError, ModuleNotFoundError): # optional dependency
try:
from typing_extensions import Annotated
except (ImportError, ModuleNotFoundError):

Annotated = NotImplemented

HAS_ANNOTATED = Annotated is not NotImplemented
119 changes: 89 additions & 30 deletions astropy/units/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,43 @@

__all__ = ['quantity_input']

import inspect
import typing as T
from numbers import Number
from collections.abc import Sequence
import inspect
from functools import wraps

import numpy as np

from astropy.utils.decorators import wraps
from .core import (Unit, UnitBase, UnitsError, add_enabled_equivalencies,
dimensionless_unscaled)
from ._typing import Annotated
from .core import (Unit, UnitBase, UnitsError,
add_enabled_equivalencies, dimensionless_unscaled)
from .function.core import FunctionUnitBase
from .physical import _unit_physical_mapping
from .physical import PhysicalType, get_physical_type
from .quantity import Quantity
from .structured import StructuredUnit


NoneType = type(None)


def _get_allowed_units(targets):
"""
From a list of target units (either as strings or unit objects) and physical
types, return a list of Unit objects.
"""

allowed_units = []
for target in targets:

try: # unit passed in as a string
target_unit = Unit(target)

except ValueError:

try: # See if the function writer specified a physical type
physical_type_id = _unit_physical_mapping[target]

except KeyError: # Function argument target is invalid
raise ValueError(f"Invalid unit or physical type '{target}'.")

# get unit directly from physical type id
target_unit = Unit._from_physical_type_id(physical_type_id)
try:
unit = Unit(target)
except (TypeError, ValueError):
try:
unit = get_physical_type(target)._unit
except (TypeError, ValueError, KeyError): # KeyError for Enum
raise ValueError(f"Invalid unit or physical type {target!r}.") from None

allowed_units.append(target_unit)
allowed_units.append(unit)

return allowed_units

Expand Down Expand Up @@ -77,8 +77,7 @@ def _validate_arg_value(param_name, func_name, arg, targets, equivalencies,

except AttributeError: # Either there is no .unit or no .is_equivalent
if hasattr(arg, "unit"):
error_msg = ("a 'unit' attribute without an 'is_equivalent' "
"method")
error_msg = ("a 'unit' attribute without an 'is_equivalent' method")
else:
error_msg = "no 'unit' attribute"

Expand All @@ -96,6 +95,45 @@ def _validate_arg_value(param_name, func_name, arg, targets, equivalencies,
raise UnitsError(f"{error_msg} '{str(targets[0])}'.")


def _parse_annotation(target):

if target in (None, NoneType, inspect._empty):
return target

# check if unit-like
try:
unit = Unit(target)
except (TypeError, ValueError):
try:
ptype = get_physical_type(target)
except (TypeError, ValueError, KeyError): # KeyError for Enum
if isinstance(target, str):
raise ValueError(f"invalid unit or physical type {target!r}.") from None
else:
return ptype
else:
return unit

# could be a type hint
origin = T.get_origin(target)
if origin is T.Union:
return [_parse_annotation(t) for t in T.get_args(target)]
elif origin is not Annotated: # can't be Quantity[]
return False

# parse type hint
cls, *annotations = T.get_args(target)
if not issubclass(cls, Quantity) or not annotations:
return False

# get unit from type hint
unit, *rest = annotations
if not isinstance(unit, (UnitBase, PhysicalType)):
return False

return unit


class QuantityInput:

@classmethod
Expand Down Expand Up @@ -144,6 +182,14 @@ def myfunction(myangle):
def myfunction(myangle: u.arcsec):
return myangle**2
Or using a unit-aware Quantity annotation.
.. code-block:: python
@u.quantity_input
def myfunction(myangle: u.Quantity[u.arcsec]):
return myangle**2
Also you can specify a return value annotation, which will
cause the function to always return a `~astropy.units.Quantity` in that
unit.
Expand Down Expand Up @@ -209,6 +255,9 @@ def wrapper(*func_args, **func_kwargs):
targets = param.annotation
is_annotation = True

# parses to unit if it's an annotation (or list thereof)
targets = _parse_annotation(targets)

# If the targets is empty, then no target units or physical
# types were specified so we can continue to the next arg
if targets is inspect.Parameter.empty:
Expand All @@ -229,7 +278,7 @@ def wrapper(*func_args, **func_kwargs):

# Check for None in the supplied list of allowed units and, if
# present and the passed value is also None, ignore.
elif None in targets:
elif None in targets or NoneType in targets:
if arg is None:
continue
else:
Expand All @@ -243,7 +292,7 @@ def wrapper(*func_args, **func_kwargs):
# non unit related annotations to pass through
if is_annotation:
valid_targets = [t for t in valid_targets
if isinstance(t, (str, UnitBase))]
if isinstance(t, (str, UnitBase, PhysicalType))]

# Now we loop over the allowed units/physical types and validate
# the value of the argument:
Expand All @@ -255,12 +304,22 @@ def wrapper(*func_args, **func_kwargs):
with add_enabled_equivalencies(self.equivalencies):
return_ = wrapped_function(*func_args, **func_kwargs)

valid_empty = (inspect.Signature.empty, None)
if (wrapped_signature.return_annotation not in valid_empty) and isinstance(
wrapped_signature.return_annotation, (str, UnitBase, FunctionUnitBase)):
return return_.to(wrapped_signature.return_annotation)
else:
return return_
# Return
ra = wrapped_signature.return_annotation
valid_empty = (inspect.Signature.empty, None, NoneType, T.NoReturn)
if ra not in valid_empty:
target = (ra if T.get_origin(ra) not in (Annotated, T.Union)
else _parse_annotation(ra))
if isinstance(target, str) or not isinstance(target, Sequence):
target = [target]
valid_targets = [t for t in target
if isinstance(t, (str, UnitBase, PhysicalType))]
_validate_arg_value("return", wrapped_function.__name__,
return_, valid_targets, self.equivalencies,
self.strict_dimensionless)
if len(valid_targets) > 0:
return_ <<= valid_targets[0]
return return_

return wrapper

Expand Down
92 changes: 92 additions & 0 deletions astropy/units/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
from .structured import StructuredUnit
from .utils import is_effectively_unity
from .format.latex import Latex
from astropy.utils.compat import NUMPY_LT_1_22
from astropy.utils.compat.misc import override__dir__
from astropy.utils.exceptions import AstropyDeprecationWarning, AstropyWarning
from astropy.utils.introspection import minversion
from astropy.utils.misc import isiterable
from astropy.utils.data_info import ParentDtypeInfo
from astropy import config as _config
Expand Down Expand Up @@ -314,6 +316,96 @@ class Quantity(np.ndarray):

__array_priority__ = 10000

def __class_getitem__(cls, unit_shape_dtype):
"""Quantity Type Hints.
Unit-aware type hints are ``Annotated`` objects that encode the class,
the unit, and possibly shape and dtype information, depending on the
python and :mod:`numpy` versions.
Schematically, ``Annotated[cls[shape, dtype], unit]``
As a classmethod, the type is the class, ie ``Quantity``
produces an ``Annotated[Quantity, ...]`` while a subclass
like :class:`~astropy.coordinates.Angle` returns
``Annotated[Angle, ...]``.
Parameters
----------
unit_shape_dtype : :class:`~astropy.units.UnitBase`, str, `~astropy.units.PhysicalType`, or tuple
Unit specification, can be the physical type (ie str or class).
If tuple, then the first element is the unit specification
and all other elements are for `numpy.ndarray` type annotations.
Whether they are included depends on the python and :mod:`numpy`
versions.
Returns
-------
`typing.Annotated`, `typing_extensions.Annotated`, `astropy.units.Unit`, or `astropy.units.PhysicalType`
Return type in this preference order:
* if python v3.9+ : `typing.Annotated`
* if :mod:`typing_extensions` is installed : `typing_extensions.Annotated`
* `astropy.units.Unit` or `astropy.units.PhysicalType`
Raises
------
TypeError
If the unit/physical_type annotation is not Unit-like or
PhysicalType-like.
Examples
--------
Create a unit-aware Quantity type annotation
>>> Quantity[Unit("s")]
Annotated[Quantity, Unit("s")]
See Also
--------
`~astropy.units.quantity_input`
Use annotations for unit checks on function arguments and results.
Notes
-----
With Python 3.9+ or :mod:`typing_extensions`, |Quantity| types are also
static-type compatible.
"""
# LOCAL
from ._typing import HAS_ANNOTATED, Annotated

# process whether [unit] or [unit, shape, ptype]
if isinstance(unit_shape_dtype, tuple): # unit, shape, dtype
target = unit_shape_dtype[0]
shape_dtype = unit_shape_dtype[1:]
else: # just unit
target = unit_shape_dtype
shape_dtype = ()

# Allowed unit/physical types. Errors if neither.
try:
unit = Unit(target)
except (TypeError, ValueError):
from astropy.units.physical import get_physical_type

try:
unit = get_physical_type(target)
except (TypeError, ValueError, KeyError): # KeyError for Enum
raise TypeError("unit annotation is not a Unit or PhysicalType") from None

# Allow to sort of work for python 3.8- / no typing_extensions
# instead of bailing out, return the unit for `quantity_input`
if not HAS_ANNOTATED:
warnings.warn("Quantity annotations are valid static type annotations only"
" if Python is v3.9+ or `typing_extensions` is installed.")
return unit

# Quantity does not (yet) properly extend the NumPy generics types,
# introduced in numpy v1.22+, instead just including the unit info as
# metadata using Annotated.
if not NUMPY_LT_1_22:
cls = super().__class_getitem__((cls, *shape_dtype))
return Annotated.__class_getitem__((cls, unit))

def __new__(cls, value, unit=None, dtype=None, copy=True, order=None,
subok=False, ndmin=0):

Expand Down

0 comments on commit 73541a2

Please sign in to comment.