Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an Array Protocol & improve static typing support #589

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,29 @@ repos:
rev: 23.7.0
hooks:
- id: black

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.0.0"
hooks:
- id: mypy
additional_dependencies: [typing_extensions>=4.4.0]
args:
- --ignore-missing-imports
- --config=pyproject.toml
files: ".*(_draft.*)$"
exclude: |
(?x)^(
.*creation_functions.py|
.*data_type_functions.py|
.*elementwise_functions.py|
.*fft.py|
.*indexing_functions.py|
.*linalg.py|
.*linear_algebra_functions.py|
.*manipulation_functions.py|
.*searching_functions.py|
.*set_functions.py|
.*sorting_functions.py|
.*statistical_functions.py|
.*utility_functions.py|
)$
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,15 @@ doc = [
requires = ["setuptools"]
build-backend = "setuptools.build_meta"


[tool.black]
line-length = 88


[tool.mypy]
python_version = "3.9"
mypy_path = "$MYPY_CONFIG_FILE_DIR/src/array_api_stubs/_draft/"
files = [
"src/array_api_stubs/_draft/**/*.py"
]
follow_imports = "silent"
126 changes: 63 additions & 63 deletions spec/draft/API_specification/array_object.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,47 +30,47 @@ Arithmetic Operators

A conforming implementation of the array API standard must provide and support an array object supporting the following Python arithmetic operators.

- ``+x``: :meth:`.array.__pos__`
- ``+x``: :meth:`.Array.__pos__`

- `operator.pos(x) <https://docs.python.org/3/library/operator.html#operator.pos>`_
- `operator.__pos__(x) <https://docs.python.org/3/library/operator.html#operator.__pos__>`_

- `-x`: :meth:`.array.__neg__`
- `-x`: :meth:`.Array.__neg__`

- `operator.neg(x) <https://docs.python.org/3/library/operator.html#operator.neg>`_
- `operator.__neg__(x) <https://docs.python.org/3/library/operator.html#operator.__neg__>`_

- `x1 + x2`: :meth:`.array.__add__`
- `x1 + x2`: :meth:`.Array.__add__`

- `operator.add(x1, x2) <https://docs.python.org/3/library/operator.html#operator.add>`_
- `operator.__add__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__add__>`_

- `x1 - x2`: :meth:`.array.__sub__`
- `x1 - x2`: :meth:`.Array.__sub__`

- `operator.sub(x1, x2) <https://docs.python.org/3/library/operator.html#operator.sub>`_
- `operator.__sub__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__sub__>`_

- `x1 * x2`: :meth:`.array.__mul__`
- `x1 * x2`: :meth:`.Array.__mul__`

- `operator.mul(x1, x2) <https://docs.python.org/3/library/operator.html#operator.mul>`_
- `operator.__mul__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__mul__>`_

- `x1 / x2`: :meth:`.array.__truediv__`
- `x1 / x2`: :meth:`.Array.__truediv__`

- `operator.truediv(x1,x2) <https://docs.python.org/3/library/operator.html#operator.truediv>`_
- `operator.__truediv__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__truediv__>`_

- `x1 // x2`: :meth:`.array.__floordiv__`
- `x1 // x2`: :meth:`.Array.__floordiv__`

- `operator.floordiv(x1, x2) <https://docs.python.org/3/library/operator.html#operator.floordiv>`_
- `operator.__floordiv__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__floordiv__>`_

- `x1 % x2`: :meth:`.array.__mod__`
- `x1 % x2`: :meth:`.Array.__mod__`

- `operator.mod(x1, x2) <https://docs.python.org/3/library/operator.html#operator.mod>`_
- `operator.__mod__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__mod__>`_

- `x1 ** x2`: :meth:`.array.__pow__`
- `x1 ** x2`: :meth:`.Array.__pow__`

- `operator.pow(x1, x2) <https://docs.python.org/3/library/operator.html#operator.pow>`_
- `operator.__pow__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__pow__>`_
Expand All @@ -82,7 +82,7 @@ Array Operators

A conforming implementation of the array API standard must provide and support an array object supporting the following Python array operators.

- `x1 @ x2`: :meth:`.array.__matmul__`
- `x1 @ x2`: :meth:`.Array.__matmul__`

- `operator.matmul(x1, x2) <https://docs.python.org/3/library/operator.html#operator.matmul>`_
- `operator.__matmul__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__matmul__>`_
Expand All @@ -94,34 +94,34 @@ Bitwise Operators

A conforming implementation of the array API standard must provide and support an array object supporting the following Python bitwise operators.

- `~x`: :meth:`.array.__invert__`
- `~x`: :meth:`.Array.__invert__`

- `operator.inv(x) <https://docs.python.org/3/library/operator.html#operator.inv>`_
- `operator.invert(x) <https://docs.python.org/3/library/operator.html#operator.invert>`_
- `operator.__inv__(x) <https://docs.python.org/3/library/operator.html#operator.__inv__>`_
- `operator.__invert__(x) <https://docs.python.org/3/library/operator.html#operator.__invert__>`_

- `x1 & x2`: :meth:`.array.__and__`
- `x1 & x2`: :meth:`.Array.__and__`

- `operator.and(x1, x2) <https://docs.python.org/3/library/operator.html#operator.and>`_
- `operator.__and__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__and__>`_

- `x1 | x2`: :meth:`.array.__or__`
- `x1 | x2`: :meth:`.Array.__or__`

- `operator.or(x1, x2) <https://docs.python.org/3/library/operator.html#operator.or>`_
- `operator.__or__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__or__>`_

- `x1 ^ x2`: :meth:`.array.__xor__`
- `x1 ^ x2`: :meth:`.Array.__xor__`

- `operator.xor(x1, x2) <https://docs.python.org/3/library/operator.html#operator.xor>`_
- `operator.__xor__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__xor__>`_

- `x1 << x2`: :meth:`.array.__lshift__`
- `x1 << x2`: :meth:`.Array.__lshift__`

- `operator.lshift(x1, x2) <https://docs.python.org/3/library/operator.html#operator.lshift>`_
- `operator.__lshift__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__lshift__>`_

- `x1 >> x2`: :meth:`.array.__rshift__`
- `x1 >> x2`: :meth:`.Array.__rshift__`

- `operator.rshift(x1, x2) <https://docs.python.org/3/library/operator.html#operator.rshift>`_
- `operator.__rshift__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__rshift__>`_
Expand All @@ -133,32 +133,32 @@ Comparison Operators

A conforming implementation of the array API standard must provide and support an array object supporting the following Python comparison operators.

- `x1 < x2`: :meth:`.array.__lt__`
- `x1 < x2`: :meth:`.Array.__lt__`

- `operator.lt(x1, x2) <https://docs.python.org/3/library/operator.html#operator.lt>`_
- `operator.__lt__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__lt__>`_

- `x1 <= x2`: :meth:`.array.__le__`
- `x1 <= x2`: :meth:`.Array.__le__`

- `operator.le(x1, x2) <https://docs.python.org/3/library/operator.html#operator.le>`_
- `operator.__le__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__le__>`_

- `x1 > x2`: :meth:`.array.__gt__`
- `x1 > x2`: :meth:`.Array.__gt__`

- `operator.gt(x1, x2) <https://docs.python.org/3/library/operator.html#operator.gt>`_
- `operator.__gt__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__gt__>`_

- `x1 >= x2`: :meth:`.array.__ge__`
- `x1 >= x2`: :meth:`.Array.__ge__`

- `operator.ge(x1, x2) <https://docs.python.org/3/library/operator.html#operator.ge>`_
- `operator.__ge__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__ge__>`_

- `x1 == x2`: :meth:`.array.__eq__`
- `x1 == x2`: :meth:`.Array.__eq__`

- `operator.eq(x1, x2) <https://docs.python.org/3/library/operator.html#operator.eq>`_
- `operator.__eq__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__eq__>`_

- `x1 != x2`: :meth:`.array.__ne__`
- `x1 != x2`: :meth:`.Array.__ne__`

- `operator.ne(x1, x2) <https://docs.python.org/3/library/operator.html#operator.ne>`_
- `operator.__ne__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__ne__>`_
Expand Down Expand Up @@ -251,13 +251,13 @@ Attributes
:toctree: generated
:template: property.rst

array.dtype
array.device
array.mT
array.ndim
array.shape
array.size
array.T
Array.dtype
Array.device
Array.mT
Array.ndim
Array.shape
Array.size
Array.T

-------------------------------------------------

Expand All @@ -271,37 +271,37 @@ Methods
:toctree: generated
:template: property.rst

array.__abs__
array.__add__
array.__and__
array.__array_namespace__
array.__bool__
array.__complex__
array.__dlpack__
array.__dlpack_device__
array.__eq__
array.__float__
array.__floordiv__
array.__ge__
array.__getitem__
array.__gt__
array.__index__
array.__int__
array.__invert__
array.__le__
array.__lshift__
array.__lt__
array.__matmul__
array.__mod__
array.__mul__
array.__ne__
array.__neg__
array.__or__
array.__pos__
array.__pow__
array.__rshift__
array.__setitem__
array.__sub__
array.__truediv__
array.__xor__
array.to_device
Array.__abs__
Array.__add__
Array.__and__
Array.__array_namespace__
Array.__bool__
Array.__complex__
Array.__dlpack__
Array.__dlpack_device__
Array.__eq__
Array.__float__
Array.__floordiv__
Array.__ge__
Array.__getitem__
Array.__gt__
Array.__index__
Array.__int__
Array.__invert__
Array.__le__
Array.__lshift__
Array.__lt__
Array.__matmul__
Array.__mod__
Array.__mul__
Array.__ne__
Array.__neg__
Array.__or__
Array.__pos__
Array.__pow__
Array.__rshift__
Array.__setitem__
Array.__sub__
Array.__truediv__
Array.__xor__
Array.to_device
2 changes: 1 addition & 1 deletion spec/draft/purpose_and_scope.md
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ namespace (e.g. `import package_name.array_api`). This has two issues though:

To address both issues, a uniform way must be provided by a conforming
implementation to access the API namespace, namely a [method on the array
object](array.__array_namespace__):
object](Array.__array_namespace__):

```
xp = x.__array_namespace__()
Expand Down
4 changes: 4 additions & 0 deletions src/_array_api_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@
]
nitpick_ignore_regex = [
("py:class", ".*array"),
("py:class", ".*Array"),
("py:class", ".*device"),
("py:class", ".*Device"),
("py:class", ".*dtype"),
("py:class", ".*Self"),
("py:class", ".*NestedSequence"),
("py:class", ".*SupportsBufferProtocol"),
("py:class", ".*PyCapsule"),
Expand All @@ -77,6 +80,7 @@
"array": "array",
"Device": "device",
"Dtype": "dtype",
"DType": "dtype",
}

# Make autosummary show the signatures of functions in the tables using actual
Expand Down
24 changes: 19 additions & 5 deletions src/array_api_stubs/_draft/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
List,
Literal,
Expand All @@ -41,9 +42,22 @@
)
from enum import Enum

array = TypeVar("array")
device = TypeVar("device")
dtype = TypeVar("dtype")

if TYPE_CHECKING:
from .array_object import Array
from .data_types import DType


class Device(Protocol):
"""Protocol for device objects."""

def __eq__(self, value: Any) -> bool:
...


array = TypeVar("array", bound="Array")
device = TypeVar("device", bound=Device)
dtype = TypeVar("dtype", bound="DType")
SupportsDLPack = TypeVar("SupportsDLPack")
SupportsBufferProtocol = TypeVar("SupportsBufferProtocol")
PyCapsule = TypeVar("PyCapsule")
Expand All @@ -61,7 +75,7 @@ class finfo_object:
max: float
min: float
smallest_normal: float
dtype: dtype
dtype: DType


@dataclass
Expand All @@ -71,7 +85,7 @@ class iinfo_object:
bits: int
max: int
min: int
dtype: dtype
dtype: DType


_T_co = TypeVar("_T_co", covariant=True)
Expand Down