Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: A few updates to the array_api #20066

Merged
merged 23 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
48005c1
Allow casting in the array API asarray()
asmeurer Oct 1, 2021
9af4814
Restrict multidimensional indexing in the array API namespace
asmeurer Oct 7, 2021
cb0b9c6
Fix type promotion for numpy.array_api.where
asmeurer Oct 19, 2021
7da7a99
Print empty array_api arrays using empty()
asmeurer Oct 20, 2021
0169af7
Fix an incorrect slice bounds guard in the array API
asmeurer Oct 22, 2021
e0589fd
Disallow multiple different dtypes in the input to np.array_api.meshgrid
asmeurer Nov 5, 2021
21faf34
Remove DLPack support from numpy.array_api.asarray()
asmeurer Nov 5, 2021
d0f591d
Remove __len__ from the array API array object
asmeurer Nov 5, 2021
17d7886
Add astype() to numpy.array_api
asmeurer Nov 8, 2021
7d9edf3
Update the unique_* functions in numpy.array_api
asmeurer Nov 8, 2021
cb335d2
Add the stream argument to the array API to_device method
asmeurer Nov 8, 2021
5cae94d
Use the NamedTuple classes for the type signatures
asmeurer Nov 8, 2021
f6053aa
Add unique_counts to the array API namespace
asmeurer Nov 8, 2021
680e0a4
Remove some unused imports
asmeurer Nov 8, 2021
49a3cc9
Update the array_api indexing restrictions
asmeurer Nov 9, 2021
475b01d
Merge branch 'main' into array-api-updates2
asmeurer Nov 9, 2021
6352d11
Use a simpler type annotation for the array API to_device method
asmeurer Nov 10, 2021
b7bf06e
Fix a typo
asmeurer Nov 10, 2021
61bc679
Fix a test failure in the array_api submodule
asmeurer Nov 10, 2021
f72ed85
Merge branch 'array-api-updates2' of github.com:asmeurer/numpy into a…
asmeurer Nov 10, 2021
09aa2b5
Merge branch 'main' into array-api-updates2
asmeurer Nov 10, 2021
580a616
Add dlpack support to the array_api submodule
asmeurer Nov 11, 2021
b2af8e9
Merge branch 'main' into array-api-updates2
rgommers Nov 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 6 additions & 4 deletions numpy/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
empty,
empty_like,
eye,
_from_dlpack,
from_dlpack,
full,
full_like,
linspace,
Expand All @@ -155,7 +155,7 @@
"empty",
"empty_like",
"eye",
"_from_dlpack",
"from_dlpack",
"full",
"full_like",
"linspace",
Expand All @@ -169,6 +169,7 @@
]

from ._data_type_functions import (
astype,
broadcast_arrays,
broadcast_to,
can_cast,
Expand All @@ -178,6 +179,7 @@
)

__all__ += [
"astype",
"broadcast_arrays",
"broadcast_to",
"can_cast",
Expand Down Expand Up @@ -358,9 +360,9 @@

__all__ += ["argmax", "argmin", "nonzero", "where"]

from ._set_functions import unique
from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values

__all__ += ["unique"]
__all__ += ["unique_all", "unique_counts", "unique_inverse", "unique_values"]

from ._sorting_functions import argsort, sort

Expand Down
48 changes: 31 additions & 17 deletions numpy/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any

if TYPE_CHECKING:
from ._typing import PyCapsule, Device, Dtype
from ._typing import Any, PyCapsule, Device, Dtype

import numpy as np

Expand Down Expand Up @@ -99,9 +99,13 @@ def __repr__(self: Array, /) -> str:
"""
Performs the operation __repr__.
"""
prefix = "Array("
suffix = f", dtype={self.dtype.name})"
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
if 0 in self.shape:
prefix = "empty("
mid = str(self.shape)
else:
prefix = "Array("
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix

# These are various helper functions to make the array behavior match the
Expand Down Expand Up @@ -244,6 +248,10 @@ def _validate_index(key, shape):
The following cases are allowed by NumPy, but not specified by the array
API specification:

- Indices to not include an implicit ellipsis at the end. That is,
every axis of an array must be explicitly indexed or an ellipsis
included.

- The start and stop of a slice may not be out of bounds. In
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
following are allowed:
Expand All @@ -270,14 +278,18 @@ def _validate_index(key, shape):
return key
if shape == ():
return key
if len(shape) > 1:
raise IndexError(
"Multidimensional arrays must include an index for every axis or use an ellipsis"
)
size = shape[0]
# Ensure invalid slice entries are passed through.
if key.start is not None:
try:
operator.index(key.start)
except TypeError:
return key
if not (-size <= key.start <= max(0, size - 1)):
if not (-size <= key.start <= size):
raise IndexError(
"Slices with out-of-bounds start are not allowed in the array API namespace"
)
Expand Down Expand Up @@ -322,6 +334,10 @@ def _validate_index(key, shape):
zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
):
Array._validate_index(idx, (size,))
if n_ellipsis == 0 and len(key) < len(shape):
raise IndexError(
"Multidimensional arrays must include an index for every axis or use an ellipsis"
)
return key
elif isinstance(key, bool):
return key
Expand All @@ -339,7 +355,12 @@ def _validate_index(key, shape):
"newaxis indices are not allowed in the array API namespace"
)
try:
return operator.index(key)
key = operator.index(key)
if shape is not None and len(shape) > 1:
raise IndexError(
"Multidimensional arrays must include an index for every axis or use an ellipsis"
)
return key
except TypeError:
# Note: This also omits boolean arrays that are not already in
# Array() form, like a list of booleans.
Expand Down Expand Up @@ -403,16 +424,14 @@ def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
"""
Performs the operation __dlpack__.
"""
res = self._array.__dlpack__(stream=stream)
return self.__class__._new(res)
return self._array.__dlpack__(stream=stream)

def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
"""
Performs the operation __dlpack_device__.
"""
# Note: device support is required for this
res = self._array.__dlpack_device__()
return self.__class__._new(res)
return self._array.__dlpack_device__()

def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
Expand Down Expand Up @@ -527,13 +546,6 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
res = self._array.__le__(other._array)
return self.__class__._new(res)

# Note: __len__ may end up being removed from the array API spec.
def __len__(self, /) -> int:
"""
Performs the operation __len__.
"""
return self._array.__len__()

def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __lshift__.
Expand Down Expand Up @@ -995,7 +1007,9 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
res = self._array.__rxor__(other._array)
return self.__class__._new(res)

def to_device(self: Array, device: Device, /) -> Array:
def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
if device == 'cpu':
return self
raise ValueError(f"Unsupported device {device!r}")
Expand Down
19 changes: 13 additions & 6 deletions numpy/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Device,
Dtype,
NestedSequence,
SupportsDLPack,
SupportsBufferProtocol,
)
from collections.abc import Sequence
Expand All @@ -36,7 +35,6 @@ def asarray(
int,
float,
NestedSequence[bool | int | float],
SupportsDLPack,
SupportsBufferProtocol,
],
/,
Expand All @@ -60,7 +58,9 @@ def asarray(
if copy is False:
# Note: copy=False is not yet implemented in np.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype):
if isinstance(obj, Array):
if dtype is not None and obj.dtype != dtype:
copy = True
if copy is True:
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
return obj
Expand Down Expand Up @@ -151,9 +151,10 @@ def eye(
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))


def _from_dlpack(x: object, /) -> Array:
# Note: dlpack support is not yet implemented on Array
raise NotImplementedError("DLPack support is not yet implemented")
def from_dlpack(x: object, /) -> Array:
from ._array_object import Array

return Array._new(np._from_dlpack(x))


def full(
Expand Down Expand Up @@ -240,6 +241,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
"""
from ._array_object import Array

# Note: unlike np.meshgrid, only inputs with all the same dtype are
# allowed

if len({a.dtype for a in arrays}) > 1:
raise ValueError("meshgrid inputs must all have the same dtype")

return [
Array._new(array)
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
Expand Down
7 changes: 7 additions & 0 deletions numpy/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
import numpy as np


# Note: astype is a function, not an array method as in NumPy.
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
if not copy and dtype == x.dtype:
return x
return Array._new(x._array.astype(dtype=dtype, copy=copy))


def broadcast_arrays(*arrays: Array) -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
Expand Down
1 change: 1 addition & 0 deletions numpy/array_api/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.where(condition._array, x1._array, x2._array))
89 changes: 75 additions & 14 deletions numpy/array_api/_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,91 @@

from ._array_object import Array

from typing import Tuple, Union
from typing import NamedTuple

import numpy as np

# Note: np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
# to remove polymorphic return types).

def unique(
x: Array,
/,
*,
return_counts: bool = False,
return_index: bool = False,
return_inverse: bool = False,
) -> Union[Array, Tuple[Array, ...]]:
# Note: The various unique() functions are supposed to return multiple NaNs.
# This does not match the NumPy behavior, however, this is currently left as a
# TODO in this implementation as this behavior may be reverted in np.unique().
# See https://github.com/numpy/numpy/issues/20326.

# Note: The functions here return a namedtuple (np.unique() returns a normal
# tuple).

class UniqueAllResult(NamedTuple):
values: Array
indices: Array
inverse_indices: Array
counts: Array


class UniqueCountsResult(NamedTuple):
values: Array
counts: Array


class UniqueInverseResult(NamedTuple):
values: Array
inverse_indices: Array


def unique_all(x: Array, /) -> UniqueAllResult:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.

See its docstring for more information.
"""
res = np.unique(
x._array,
return_counts=True,
return_index=True,
return_inverse=True,
)

return UniqueAllResult(*[Array._new(i) for i in res])


def unique_counts(x: Array, /) -> UniqueCountsResult:
res = np.unique(
x._array,
return_counts=True,
return_index=False,
return_inverse=False,
)

return UniqueCountsResult(*[Array._new(i) for i in res])


def unique_inverse(x: Array, /) -> UniqueInverseResult:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.

See its docstring for more information.
"""
res = np.unique(
x._array,
return_counts=False,
return_index=False,
return_inverse=True,
)
return UniqueInverseResult(*[Array._new(i) for i in res])


def unique_values(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.

See its docstring for more information.
"""
res = np.unique(
x._array,
return_counts=return_counts,
return_index=return_index,
return_inverse=return_inverse,
return_counts=False,
return_index=False,
return_inverse=False,
)
if isinstance(res, tuple):
return tuple(Array._new(i) for i in res)
return Array._new(res)