
Use NDArray instead of ArrayLike when dtype is given #442

Open · wants to merge 1 commit into base: master
Conversation

yosh-matsuda
Contributor

For ndarray typing, I would like to suggest using NDArray instead of ArrayLike when the dtype is given.

NDArray carries a data-type annotation, whereas ArrayLike appears to treat the data type as Any.
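A minimal illustration of the difference (assuming numpy is installed; the function names here are made up for the example, not part of nanobind): NDArray[np.int32] pins the element type, while ArrayLike admits anything convertible to an array and erases the dtype.

```python
import numpy as np
import numpy.typing as npt


def takes_ndarray(a: npt.NDArray[np.int32]) -> int:
    # Type checkers require an actual ndarray with int32 elements here.
    return int(a.sum())


def takes_arraylike(a: npt.ArrayLike) -> int:
    # A plain list, scalar, or nested sequence also type-checks here;
    # the dtype information is lost.
    return int(np.asarray(a).sum())


print(takes_ndarray(np.array([1, 2, 3], dtype=np.int32)))  # → 6
print(takes_arraylike([1, 2, 3]))                          # → 6
```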

@yosh-matsuda
Contributor Author

I am not sure why the test for pypy3.10 failed, but all the tests passed on my fork.
https://github.com/yosh-matsuda/nanobind/actions/runs/8123702629

@wjakob
Owner

wjakob commented Mar 3, 2024

Doesn't NDArray imply that this is actually a NumPy array? Whereas ArrayLike is a bit more loose ("could in principle be converted into a numpy array").

@wjakob
Owner

wjakob commented Mar 3, 2024

I'm actually not convinced that numpy.typing is necessarily the best kind of type to use here; perhaps there are other options as well. This one, e.g., seems interesting: https://github.com/patrick-kidger/jaxtyping

@yosh-matsuda
Contributor Author

yosh-matsuda commented Mar 5, 2024

Sorry, I had assumed that stubgen was annotating ndarray for NumPy. However, I think ArrayLike is inappropriate because it includes not only ndarray but also any type from which a numpy.ndarray can be constructed, i.e. Python sequences and scalars.

```cpp
module.def("ndarray_func1", [](nb::ndarray<std::int32_t> arr) {});
```

```python
def ndarray_func1(arg: Annotated[ArrayLike, dict(dtype='int32')], /) -> None:
```

```pycon
>>> ndarray_func1([1, 2, 3])  # type checking OK, but...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: ndarray_func1(): incompatible function arguments. The following argument types are supported:
    1. ndarray_func1(arg: ndarray[dtype=int32], /) -> None

Invoked with types: list
```

On the other hand, as you say, NDArray seems to be compatible only with numpy.

```python
import torch
import test

test.ndarray_func1(torch.tensor([1, 2, 3]))
```

mypy:

```
test.py:5: error: Argument 1 to "ndarray_func1" has incompatible type "Tensor"; expected "ndarray[Any, dtype[signedinteger[_32Bit]]]"  [arg-type]
```

I am currently trying out jaxtyping, but I have not yet managed to check multiple array types at once with a dtype specified.

@yosh-matsuda
Contributor Author

@wjakob

Since numpy is the only user module in nanobind that can be imported, what about the following idea for post-processing in stubgen?

When the ndarray framework is specified:

  • Annotated[<framework_type_name>, dict(...)] for input and output
  • For numpy, numpy.typing.NDArray is used with the dtype

When the framework is not specified for input arguments:

  • Specify the array types to annotate nb::ndarray with via stubgen arguments.
    • stubgen --numpy --jax --torch --tensorflow (tentative naming) means nb::ndarray will be annotated like Annotated[numpy.typing.NDArray[...] | jax.Array | torch.Tensor | tensorflow.(?), dict(...)].
  • numpy.typing.NDArray is enabled by default

When the framework is not specified for return values:

  • A raw "ndarray" will be used with Annotated
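Under the proposal above, a stub generated with e.g. --numpy --jax --torch might look roughly like the following sketch. This is purely illustrative, not nanobind's actual output; the framework imports are guarded behind TYPE_CHECKING so the fragment runs even without those packages installed, and the annotation is given as a string for the same reason.

```python
from typing import TYPE_CHECKING, Annotated, Any, Union

if TYPE_CHECKING:  # only type checkers resolve these imports
    import jax
    import torch
    from numpy.typing import NDArray


def ndarray_func1(
    arg: "Annotated[Union[NDArray[Any], jax.Array, torch.Tensor], dict(dtype='int32')]",
    /,
) -> None:
    # Stub body; a real stub file would use `...` with no implementation.
    ...
```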

@wjakob
Owner

wjakob commented Mar 8, 2024

I think that your answer applies to the output end (a function returning an nd-array in a specific framework).

On the input end, the situation is rather more complex. Nanobind will accept anything that implements the buffer protocol or DLPack protocol as input. That could be encoded as follows:

```python
# Contents of a hypothetical nanobind.typing module

from collections.abc import Buffer
from typing import Protocol, Any, TypeAlias, Union
from types import CapsuleType


class DLPackBuffer(Protocol):
    def __dlpack__(
        self,
        stream: Any = None,
        max_version: tuple[int, int] | None = None,
        dl_device: Any | None = None,
        copy: bool | None = None,
    ) -> CapsuleType: ...


NDArray: TypeAlias = Union[Buffer, DLPackBuffer]
```
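As a quick sanity check of the protocol idea, here is a simplified, runtime-checkable variant. It uses Any in place of types.CapsuleType (which only exists in very recent Pythons), takes no `__dlpack__` arguments, and FakeTensor is a made-up stand-in for a framework tensor, so none of this is nanobind's actual API.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class DLPackBuffer(Protocol):
    def __dlpack__(self) -> Any: ...


class FakeTensor:
    """Stand-in for a framework tensor exposing the DLPack protocol."""

    def __dlpack__(self) -> Any:
        return object()  # a real framework would return a PyCapsule


# runtime_checkable protocols only verify that the method exists:
print(isinstance(FakeTensor(), DLPackBuffer))  # → True
print(isinstance([1, 2, 3], DLPackBuffer))     # → False
```

Note that isinstance checks against runtime-checkable protocols only confirm the attribute's presence, not its signature, which is one reason the zero-argument form is pragmatic here.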

@yosh-matsuda yosh-matsuda marked this pull request as draft March 9, 2024 07:54
@yosh-matsuda yosh-matsuda force-pushed the typing-ndarray branch 2 times, most recently from 24b888a to cf0dd0e Compare March 9, 2024 08:34
@yosh-matsuda yosh-matsuda marked this pull request as ready for review March 9, 2024 11:52
@yosh-matsuda
Contributor Author

@wjakob Could you review the last commit (force pushed) in this PR? The changes are as follows:

  • Added an array protocol class NDArray for nd-arrays in the stub file, based on your suggestion.
    • Since implementations of __dlpack__ in various frameworks do not seem to strictly follow the Python array API standard, the protocol DLPackBuffer.__dlpack__ takes no arguments.
    • I checked that numpy, torch, and jax arrays are accepted, but not tensorflow.
      (I am not sure whether tensorflow.python.framework.ops.EagerTensor has __dlpack__.)
  • If the framework is specified in nb::ndarray, the framework type will be Annotated with metadata.
    • numpy.ndarray will be replaced with numpy.typing.NDArray[<dtype>].
  • If not, the typing of nb::ndarray will be Annotated[NDArray, meta].

Please see the example stub file in tests: https://github.com/wjakob/nanobind/blob/ba12ce8ce410bdea65bf03f205f3e8d990019150/tests/test_ndarray_ext.pyi.ref
