Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: A few updates to the array_api #20066

Merged
merged 23 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
48005c1
Allow casting in the array API asarray()
asmeurer Oct 1, 2021
9af4814
Restrict multidimensional indexing in the array API namespace
asmeurer Oct 7, 2021
cb0b9c6
Fix type promotion for numpy.array_api.where
asmeurer Oct 19, 2021
7da7a99
Print empty array_api arrays using empty()
asmeurer Oct 20, 2021
0169af7
Fix an incorrect slice bounds guard in the array API
asmeurer Oct 22, 2021
e0589fd
Disallow multiple different dtypes in the input to np.array_api.meshgrid
asmeurer Nov 5, 2021
21faf34
Remove DLPack support from numpy.array_api.asarray()
asmeurer Nov 5, 2021
d0f591d
Remove __len__ from the array API array object
asmeurer Nov 5, 2021
17d7886
Add astype() to numpy.array_api
asmeurer Nov 8, 2021
7d9edf3
Update the unique_* functions in numpy.array_api
asmeurer Nov 8, 2021
cb335d2
Add the stream argument to the array API to_device method
asmeurer Nov 8, 2021
5cae94d
Use the NamedTuple classes for the type signatures
asmeurer Nov 8, 2021
f6053aa
Add unique_counts to the array API namespace
asmeurer Nov 8, 2021
680e0a4
Remove some unused imports
asmeurer Nov 8, 2021
49a3cc9
Update the array_api indexing restrictions
asmeurer Nov 9, 2021
475b01d
Merge branch 'main' into array-api-updates2
asmeurer Nov 9, 2021
6352d11
Use a simpler type annotation for the array API to_device method
asmeurer Nov 10, 2021
b7bf06e
Fix a typo
asmeurer Nov 10, 2021
61bc679
Fix a test failure in the array_api submodule
asmeurer Nov 10, 2021
f72ed85
Merge branch 'array-api-updates2' of github.com:asmeurer/numpy into a…
asmeurer Nov 10, 2021
09aa2b5
Merge branch 'main' into array-api-updates2
asmeurer Nov 10, 2021
580a616
Add dlpack support to the array_api submodule
asmeurer Nov 11, 2021
b2af8e9
Merge branch 'main' into array-api-updates2
rgommers Nov 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions numpy/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ def _validate_index(key, shape):
The following cases are allowed by NumPy, but not specified by the array
API specification:

- Multidimensional (tuple) indices must include an index for every
axis of the array unless an ellipsis is included.

- The start and stop of a slice may not be out of bounds. In
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
following are allowed:
Expand Down Expand Up @@ -322,6 +325,10 @@ def _validate_index(key, shape):
zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
):
Array._validate_index(idx, (size,))
if n_ellipsis == 0 and len(key) < len(shape):
raise IndexError(
"Multidimensional indices must either index every axis of the array or use an ellipsis"
rgommers marked this conversation as resolved.
Show resolved Hide resolved
)
return key
elif isinstance(key, bool):
return key
Expand Down
4 changes: 3 additions & 1 deletion numpy/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def asarray(
if copy is False:
# Note: copy=False is not yet implemented in np.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype):
if isinstance(obj, Array):
if dtype is not None and obj.dtype != dtype:
copy = True
if copy is True:
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
return obj
Expand Down
3 changes: 3 additions & 0 deletions numpy/array_api/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[None, ...])
assert_raises(IndexError, lambda: a[..., None])

# Multiaxis indices must contain exactly as many indices as dimensions
assert_raises(IndexError, lambda: a[()])
assert_raises(IndexError, lambda: a[0,])
rgommers marked this conversation as resolved.
Show resolved Hide resolved

def test_operators():
# For every operator, we test that it works for the required type
Expand Down