Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Implement the DLPack Array API protocols for ndarray. #19083

Merged
merged 18 commits into from Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions doc/neps/nep-0047-array-api-standard.rst
Expand Up @@ -338,9 +338,10 @@ the options already present in NumPy are:

Adding support for DLPack to NumPy entails:

- Adding a ``ndarray.__dlpack__`` method.
- Adding a ``from_dlpack`` function, which takes as input an object
supporting ``__dlpack__``, and returns an ``ndarray``.
- Adding a ``ndarray.__dlpack__()`` method which returns a ``dlpack`` C
structure wrapped in a ``PyCapsule``.
- Adding a ``np._from_dlpack(obj)`` function, where ``obj`` supports
``__dlpack__()``, and returns an ``ndarray``.

DLPack is currently a ~200 LoC header, and is meant to be included directly, so
no external dependency is needed. Implementation should be straightforward.
Expand Down
6 changes: 6 additions & 0 deletions doc/release/upcoming_changes/19083.new_feature.rst
@@ -0,0 +1,6 @@
Add NEP 47-compatible dlpack support
------------------------------------

Add a ``ndarray.__dlpack__()`` method which returns a ``dlpack`` C structure
wrapped in a ``PyCapsule``. Also add a ``np._from_dlpack(obj)`` function, where
``obj`` supports ``__dlpack__()``, and returns an ``ndarray``.
17 changes: 17 additions & 0 deletions numpy/__init__.pyi
Expand Up @@ -1413,6 +1413,7 @@ _SupportsBuffer = Union[

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)
_2Tuple = Tuple[_T, _T]
_CastingKind = L["no", "equiv", "safe", "same_kind", "unsafe"]

Expand All @@ -1432,6 +1433,10 @@ _ArrayTD64_co = NDArray[Union[bool_, integer[Any], timedelta64]]
# Introduce an alias for `dtype` to avoid naming conflicts.
_dtype = dtype

# `builtins.PyCapsule` unfortunately lacks annotations as of the moment;
# use `Any` as a stopgap measure
_PyCapsule = Any
hameerabbasi marked this conversation as resolved.
Show resolved Hide resolved

Comment on lines +1436 to +1439
Copy link
Member

@BvB93 BvB93 May 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interestingly, if we could get builtins.PyCapsule in typeshed annotated as a parameterizable type,
then it would in principle be possible for static type checkers to read, e.g., the underlying shape and dtype.

The only thing that users or libraries would have to do here is declare the necessary annotations.

Perhaps something to consider for the future?

Examples

from typing import TypeVar, Any, Generic, Tuple as Tuple
import numpy as np

# Improvised `PyCapsule` annotation
_T = TypeVar("_T")
class PyCapsule(Generic[_T]): ...

# Construct a more compact `PyCapsule` alias; `Tuple` used herein to introduce 2 parameters 
# (there may be more appropriate types that can fulfill this functionality)
_Shape = TypeVar("_Shape", bound=Any)  # TODO: Wait for PEP 646's TypeVarTuple
_DType = TypeVar("_DType", bound=np.dtype[Any])
DLPackCapsule = PyCapsule[Tuple[_Shape, _DType]]

# A practical example
def from_dlpack(__x: DLPackCapsule[_Shape, _DType]) -> np.ndarray[_Shape, _DType]: ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll consider this as out of scope of this PR for now, but will leave the conversation unresolved for visibility.

class _SupportsItem(Protocol[_T_co]):
def item(self, args: Any, /) -> _T_co: ...

Expand Down Expand Up @@ -2439,6 +2444,12 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def __ior__(self: NDArray[signedinteger[_NBit1]], other: _ArrayLikeInt_co) -> NDArray[signedinteger[_NBit1]]: ...
@overload
def __ior__(self: NDArray[object_], other: Any) -> NDArray[object_]: ...
@overload
def __ior__(self: NDArray[_ScalarType], other: _RecursiveSequence) -> NDArray[_ScalarType]: ...
@overload
def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ...
@overload
def __dlpack_device__(self) -> Tuple[int, L[0]]: ...

# Keep `dtype` at the bottom to avoid name conflicts with `np.dtype`
@property
Expand Down Expand Up @@ -4320,3 +4331,9 @@ class chararray(ndarray[_ShapeType, _CharDType]):

# NOTE: Deprecated
# class MachAr: ...

class _SupportsDLPack(Protocol[_T_contra]):
def __dlpack__(self, *, stream: None | _T_contra = ...) -> _PyCapsule: ...

def _from_dlpack(__obj: _SupportsDLPack[None]) -> NDArray[Any]: ...

4 changes: 2 additions & 2 deletions numpy/array_api/__init__.py
Expand Up @@ -136,7 +136,7 @@
empty,
empty_like,
eye,
from_dlpack,
_from_dlpack,
full,
full_like,
linspace,
Expand All @@ -155,7 +155,7 @@
"empty",
"empty_like",
"eye",
"from_dlpack",
"_from_dlpack",
"full",
"full_like",
"linspace",
Expand Down
2 changes: 1 addition & 1 deletion numpy/array_api/_creation_functions.py
Expand Up @@ -151,7 +151,7 @@ def eye(
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))


def from_dlpack(x: object, /) -> Array:
def _from_dlpack(x: object, /) -> Array:
# Note: dlpack support is not yet implemented on Array
raise NotImplementedError("DLPack support is not yet implemented")

Expand Down
22 changes: 22 additions & 0 deletions numpy/core/_add_newdocs.py
Expand Up @@ -1573,6 +1573,19 @@
array_function_like_doc,
))

add_newdoc('numpy.core.multiarray', '_from_dlpack',
"""
_from_dlpack(x, /)

Create a NumPy array from an object implementing the ``__dlpack__``
protocol.

See Also
--------
`Array API documentation
<https://data-apis.org/array-api/latest/design_topics/data_interchange.html#syntax-for-data-interchange-with-dlpack>`_
""")

add_newdoc('numpy.core', 'fastCopyAndTranspose',
"""_fastCopyAndTranspose(a)""")

Expand Down Expand Up @@ -2263,6 +2276,15 @@
add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_struct__',
"""Array protocol: C-struct side."""))

add_newdoc('numpy.core.multiarray', 'ndarray', ('__dlpack__',
"""a.__dlpack__(*, stream=None)

DLPack Protocol: Part of the Array API."""))

add_newdoc('numpy.core.multiarray', 'ndarray', ('__dlpack_device__',
"""a.__dlpack_device__()

DLPack Protocol: Part of the Array API."""))

add_newdoc('numpy.core.multiarray', 'ndarray', ('base',
"""
Expand Down
1 change: 1 addition & 0 deletions numpy/core/code_generators/genapi.py
Expand Up @@ -41,6 +41,7 @@
join('multiarray', 'datetime_busdaycal.c'),
join('multiarray', 'datetime_strings.c'),
join('multiarray', 'descriptor.c'),
join('multiarray', 'dlpack.c'),
join('multiarray', 'dtypemeta.c'),
join('multiarray', 'einsum.c.src'),
join('multiarray', 'flagsobject.c'),
Expand Down
22 changes: 12 additions & 10 deletions numpy/core/multiarray.py
Expand Up @@ -14,27 +14,28 @@
# do not change them. issue gh-15518
# _get_ndarray_c_version is semi-public, on purpose not added to __all__
from ._multiarray_umath import (
_fastCopyAndTranspose, _flagdict, _insert, _reconstruct, _vec_string,
_ARRAY_API, _monotonicity, _get_ndarray_c_version, _set_madvise_hugepage,
_fastCopyAndTranspose, _flagdict, _from_dlpack, _insert, _reconstruct,
_vec_string, _ARRAY_API, _monotonicity, _get_ndarray_c_version,
_set_madvise_hugepage,
)

__all__ = [
'_ARRAY_API', 'ALLOW_THREADS', 'BUFSIZE', 'CLIP', 'DATETIMEUNITS',
'ITEM_HASOBJECT', 'ITEM_IS_POINTER', 'LIST_PICKLE', 'MAXDIMS',
'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', 'NEEDS_INIT', 'NEEDS_PYAPI',
'RAISE', 'USE_GETITEM', 'USE_SETITEM', 'WRAP', '_fastCopyAndTranspose',
'_flagdict', '_insert', '_reconstruct', '_vec_string', '_monotonicity',
'add_docstring', 'arange', 'array', 'asarray', 'asanyarray',
'ascontiguousarray', 'asfortranarray', 'bincount', 'broadcast',
'busday_count', 'busday_offset', 'busdaycalendar', 'can_cast',
'_flagdict', '_from_dlpack', '_insert', '_reconstruct', '_vec_string',
'_monotonicity', 'add_docstring', 'arange', 'array', 'asarray',
'asanyarray', 'ascontiguousarray', 'asfortranarray', 'bincount',
'broadcast', 'busday_count', 'busday_offset', 'busdaycalendar', 'can_cast',
'compare_chararrays', 'concatenate', 'copyto', 'correlate', 'correlate2',
'count_nonzero', 'c_einsum', 'datetime_as_string', 'datetime_data',
'dot', 'dragon4_positional', 'dragon4_scientific', 'dtype',
'empty', 'empty_like', 'error', 'flagsobj', 'flatiter', 'format_longfloat',
'frombuffer', 'fromfile', 'fromiter', 'fromstring', 'get_handler_name',
'inner', 'interp', 'interp_complex', 'is_busday', 'lexsort',
'matmul', 'may_share_memory', 'min_scalar_type', 'ndarray', 'nditer',
'nested_iters', 'normalize_axis_index', 'packbits',
'frombuffer', 'fromfile', 'fromiter', 'fromstring',
'get_handler_name', 'inner', 'interp', 'interp_complex', 'is_busday',
'lexsort', 'matmul', 'may_share_memory', 'min_scalar_type', 'ndarray',
'nditer', 'nested_iters', 'normalize_axis_index', 'packbits',
'promote_types', 'putmask', 'ravel_multi_index', 'result_type', 'scalar',
'set_datetimeparse_function', 'set_legacy_print_mode', 'set_numeric_ops',
'set_string_function', 'set_typeDict', 'shares_memory',
Expand All @@ -46,6 +47,7 @@
scalar.__module__ = 'numpy.core.multiarray'


_from_dlpack.__module__ = 'numpy'
arange.__module__ = 'numpy'
array.__module__ = 'numpy'
asarray.__module__ = 'numpy'
Expand Down
6 changes: 3 additions & 3 deletions numpy/core/numeric.py
Expand Up @@ -13,8 +13,8 @@
WRAP, arange, array, asarray, asanyarray, ascontiguousarray,
asfortranarray, broadcast, can_cast, compare_chararrays,
concatenate, copyto, dot, dtype, empty,
empty_like, flatiter, frombuffer, fromfile, fromiter, fromstring,
inner, lexsort, matmul, may_share_memory,
empty_like, flatiter, frombuffer, _from_dlpack, fromfile, fromiter,
fromstring, inner, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
zeros, normalize_axis_index)
Expand All @@ -41,7 +41,7 @@
'newaxis', 'ndarray', 'flatiter', 'nditer', 'nested_iters', 'ufunc',
'arange', 'array', 'asarray', 'asanyarray', 'ascontiguousarray',
'asfortranarray', 'zeros', 'count_nonzero', 'empty', 'broadcast', 'dtype',
'fromstring', 'fromfile', 'frombuffer', 'where',
'fromstring', 'fromfile', 'frombuffer', '_from_dlpack', 'where',
'argwhere', 'copyto', 'concatenate', 'fastCopyAndTranspose', 'lexsort',
'set_numeric_ops', 'can_cast', 'promote_types', 'min_scalar_type',
'result_type', 'isfortran', 'empty_like', 'zeros_like', 'ones_like',
Expand Down
3 changes: 3 additions & 0 deletions numpy/core/setup.py
Expand Up @@ -740,6 +740,7 @@ def gl_if_msvc(build_cmd):
#######################################################################

common_deps = [
join('src', 'common', 'dlpack', 'dlpack.h'),
join('src', 'common', 'array_assign.h'),
join('src', 'common', 'binop_override.h'),
join('src', 'common', 'cblasfuncs.h'),
Expand All @@ -749,6 +750,7 @@ def gl_if_msvc(build_cmd):
join('src', 'common', 'npy_cblas.h'),
join('src', 'common', 'npy_config.h'),
join('src', 'common', 'npy_ctypes.h'),
join('src', 'common', 'npy_dlpack.h'),
join('src', 'common', 'npy_extint128.h'),
join('src', 'common', 'npy_import.h'),
join('src', 'common', 'npy_hashtable.h'),
Expand Down Expand Up @@ -881,6 +883,7 @@ def gl_if_msvc(build_cmd):
join('src', 'multiarray', 'datetime_busday.c'),
join('src', 'multiarray', 'datetime_busdaycal.c'),
join('src', 'multiarray', 'descriptor.c'),
join('src', 'multiarray', 'dlpack.c'),
join('src', 'multiarray', 'dtypemeta.c'),
join('src', 'multiarray', 'dragon4.c'),
join('src', 'multiarray', 'dtype_transfer.c'),
Expand Down