Skip to content

Commit

Permalink
apply diff from numpy/numpy#20066
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer authored and leofang committed Nov 12, 2021
1 parent 05a4cea commit 797b2af
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 34 deletions.
6 changes: 4 additions & 2 deletions cupy/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@
]

from ._data_type_functions import (
astype,
broadcast_arrays,
broadcast_to,
can_cast,
Expand All @@ -185,6 +186,7 @@
)

__all__ += [
"astype",
"broadcast_arrays",
"broadcast_to",
"can_cast",
Expand Down Expand Up @@ -365,9 +367,9 @@

__all__ += ["argmax", "argmin", "nonzero", "where"]

from ._set_functions import unique
from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values

__all__ += ["unique"]
__all__ += ["unique_all", "unique_counts", "unique_inverse", "unique_values"]

from ._sorting_functions import argsort, sort

Expand Down
39 changes: 28 additions & 11 deletions cupy/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any

if TYPE_CHECKING:
from ._typing import PyCapsule, Device, Dtype
from ._typing import Any, PyCapsule, Device, Dtype

import cupy as np
from cupy.cuda import Device as _Device
Expand Down Expand Up @@ -101,7 +101,14 @@ def __repr__(self: Array, /) -> str:
"""
Performs the operation __repr__.
"""
return f"Array({repr(self._array)}, dtype={self.dtype.name})"
suffix = f", dtype={self.dtype.name})"
if 0 in self.shape:
prefix = "empty("
mid = str(self.shape)
else:
prefix = "Array("
mid = np.array2string(np.asnumpy(self._array), separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix

# These are various helper functions to make the array behavior match the
# spec in places where it either deviates from or is more strict than
Expand Down Expand Up @@ -243,6 +250,10 @@ def _validate_index(key, shape):
The following cases are allowed by NumPy, but not specified by the array
API specification:
- Indices to not include an implicit ellipsis at the end. That is,
every axis of an array must be explicitly indexed or an ellipsis
included.
- The start and stop of a slice may not be out of bounds. In
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
following are allowed:
Expand All @@ -269,14 +280,18 @@ def _validate_index(key, shape):
return key
if shape == ():
return key
if len(shape) > 1:
raise IndexError(
"Multidimensional arrays must include an index for every axis or use an ellipsis"
)
size = shape[0]
# Ensure invalid slice entries are passed through.
if key.start is not None:
try:
operator.index(key.start)
except TypeError:
return key
if not (-size <= key.start <= max(0, size - 1)):
if not (-size <= key.start <= size):
raise IndexError(
"Slices with out-of-bounds start are not allowed in the array API namespace"
)
Expand Down Expand Up @@ -321,6 +336,10 @@ def _validate_index(key, shape):
zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
):
Array._validate_index(idx, (size,))
if n_ellipsis == 0 and len(key) < len(shape):
raise IndexError(
"Multidimensional arrays must include an index for every axis or use an ellipsis"
)
return key
elif isinstance(key, bool):
return key
Expand All @@ -338,7 +357,12 @@ def _validate_index(key, shape):
"newaxis indices are not allowed in the array API namespace"
)
try:
return operator.index(key)
key = operator.index(key)
if shape is not None and len(shape) > 1:
raise IndexError(
"Multidimensional arrays must include an index for every axis or use an ellipsis"
)
return key
except TypeError:
# Note: This also omits boolean arrays that are not already in
# Array() form, like a list of booleans.
Expand Down Expand Up @@ -526,13 +550,6 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
res = self._array.__le__(other._array)
return self.__class__._new(res)

# Note: __len__ may end up being removed from the array API spec.
def __len__(self, /) -> int:
"""
Performs the operation __len__.
"""
return self._array.__len__()

def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __lshift__.
Expand Down
8 changes: 6 additions & 2 deletions cupy/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Device,
Dtype,
NestedSequence,
SupportsDLPack,
SupportsBufferProtocol,
)
from collections.abc import Sequence
Expand Down Expand Up @@ -38,7 +37,6 @@ def asarray(
int,
float,
NestedSequence[bool | int | float],
SupportsDLPack,
SupportsBufferProtocol,
],
/,
Expand Down Expand Up @@ -307,6 +305,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
"""
from ._array_object import Array

# Note: unlike np.meshgrid, only inputs with all the same dtype are
# allowed

if len({a.dtype for a in arrays}) > 1:
raise ValueError("meshgrid inputs must all have the same dtype")

return [
Array._new(array)
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
Expand Down
7 changes: 7 additions & 0 deletions cupy/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
import cupy as np


# Note: astype is a function, not an array method as in NumPy.
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
if not copy and dtype == x.dtype:
return x
return Array._new(x._array.astype(dtype=dtype, copy=copy))


def broadcast_arrays(*arrays: Array) -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
Expand Down
1 change: 1 addition & 0 deletions cupy/array_api/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.where(condition._array, x1._array, x2._array))
89 changes: 75 additions & 14 deletions cupy/array_api/_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,91 @@

from ._array_object import Array

from typing import Tuple, Union
from typing import NamedTuple

import cupy as np

# Note: np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
# to remove polymorphic return types).

def unique(
x: Array,
/,
*,
return_counts: bool = False,
return_index: bool = False,
return_inverse: bool = False,
) -> Union[Array, Tuple[Array, ...]]:
# Note: The various unique() functions are supposed to return multiple NaNs.
# This does not match the NumPy behavior, however, this is currently left as a
# TODO in this implementation as this behavior may be reverted in np.unique().
# See https://github.com/numpy/numpy/issues/20326.

# Note: The functions here return a namedtuple (np.unique() returns a normal
# tuple).

class UniqueAllResult(NamedTuple):
values: Array
indices: Array
inverse_indices: Array
counts: Array


class UniqueCountsResult(NamedTuple):
values: Array
counts: Array


class UniqueInverseResult(NamedTuple):
values: Array
inverse_indices: Array


def unique_all(x: Array, /) -> UniqueAllResult:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
See its docstring for more information.
"""
res = np.unique(
x._array,
return_counts=True,
return_index=True,
return_inverse=True,
)

return UniqueAllResult(*[Array._new(i) for i in res])


def unique_counts(x: Array, /) -> UniqueCountsResult:
res = np.unique(
x._array,
return_counts=True,
return_index=False,
return_inverse=False,
)

return UniqueCountsResult(*[Array._new(i) for i in res])


def unique_inverse(x: Array, /) -> UniqueInverseResult:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
See its docstring for more information.
"""
res = np.unique(
x._array,
return_counts=False,
return_index=False,
return_inverse=True,
)
return UniqueInverseResult(*[Array._new(i) for i in res])


def unique_values(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
See its docstring for more information.
"""
res = np.unique(
x._array,
return_counts=return_counts,
return_index=return_index,
return_inverse=return_inverse,
return_counts=False,
return_index=False,
return_inverse=False,
)
if isinstance(res, tuple):
return tuple(Array._new(i) for i in res)
return Array._new(res)
15 changes: 10 additions & 5 deletions tests/cupy_tests/array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[:-4])
assert_raises(IndexError, lambda: a[:3:-1])
assert_raises(IndexError, lambda: a[:-5:-1])
assert_raises(IndexError, lambda: a[3:])
assert_raises(IndexError, lambda: a[4:])
assert_raises(IndexError, lambda: a[-4:])
assert_raises(IndexError, lambda: a[3::-1])
assert_raises(IndexError, lambda: a[4::-1])
assert_raises(IndexError, lambda: a[-4::-1])

assert_raises(IndexError, lambda: a[...,:5])
assert_raises(IndexError, lambda: a[...,:-5])
assert_raises(IndexError, lambda: a[...,:4:-1])
assert_raises(IndexError, lambda: a[...,:5:-1])
assert_raises(IndexError, lambda: a[...,:-6:-1])
assert_raises(IndexError, lambda: a[...,4:])
assert_raises(IndexError, lambda: a[...,5:])
assert_raises(IndexError, lambda: a[...,-5:])
assert_raises(IndexError, lambda: a[...,4::-1])
assert_raises(IndexError, lambda: a[...,5::-1])
assert_raises(IndexError, lambda: a[...,-5::-1])

# Boolean indices cannot be part of a larger tuple index
Expand All @@ -74,6 +74,11 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[None, ...])
assert_raises(IndexError, lambda: a[..., None])

# Multiaxis indices must contain exactly as many indices as dimensions
assert_raises(IndexError, lambda: a[()])
assert_raises(IndexError, lambda: a[0,])
assert_raises(IndexError, lambda: a[0])
assert_raises(IndexError, lambda: a[:])

def test_operators():
# For every operator, we test that it works for the required type
Expand Down
8 changes: 8 additions & 0 deletions tests/cupy_tests/array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,11 @@ def test_zeros_like_errors():
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i"))
zeros_like(asarray(1), device=Device()) # on current device

def test_meshgrid_dtype_errors():
# Doesn't raise
meshgrid()
meshgrid(asarray([1.], dtype=float32))
meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32))

assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64)))

0 comments on commit 797b2af

Please sign in to comment.