Skip to content

Commit

Permalink
Merge pull request #22357 from BvB93/runtime
Browse files Browse the repository at this point in the history
TYP,ENH: Mark `numpy.typing` protocols as runtime checkable
  • Loading branch information
charris committed Oct 2, 2022
2 parents b654752 + 51c9aa8 commit 5490d87
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 3 deletions.
6 changes: 6 additions & 0 deletions doc/release/upcoming_changes/22357.improvement.rst
@@ -0,0 +1,6 @@
``numpy.typing`` protocols are now runtime checkable
----------------------------------------------------

The protocols used in `~numpy.typing.ArrayLike` and `~numpy.typing.DTypeLike`
are now properly marked as runtime checkable, making them easier to use for
runtime type checkers.
6 changes: 4 additions & 2 deletions numpy/_typing/_array_like.py
Expand Up @@ -3,7 +3,7 @@
# NOTE: Import `Sequence` from `typing` as we it is needed for a type-alias,
# not an annotation
from collections.abc import Collection, Callable
from typing import Any, Sequence, Protocol, Union, TypeVar
from typing import Any, Sequence, Protocol, Union, TypeVar, runtime_checkable
from numpy import (
ndarray,
dtype,
Expand Down Expand Up @@ -33,10 +33,12 @@
# array.
# Concrete implementations of the protocol are responsible for adding
# any and all remaining overloads
@runtime_checkable
class _SupportsArray(Protocol[_DType_co]):
def __array__(self) -> ndarray[Any, _DType_co]: ...


@runtime_checkable
class _SupportsArrayFunc(Protocol):
"""A protocol class representing `~class.__array_function__`."""
def __array_function__(
Expand Down Expand Up @@ -146,7 +148,7 @@ def __array_function__(
# Used as the first overload, should only match NDArray[Any],
# not any actual types.
# https://github.com/numpy/numpy/pull/22193
class _UnknownType:
class _UnknownType:
...


Expand Down
2 changes: 2 additions & 0 deletions numpy/_typing/_dtype_like.py
Expand Up @@ -8,6 +8,7 @@
TypeVar,
Protocol,
TypedDict,
runtime_checkable,
)

import numpy as np
Expand Down Expand Up @@ -80,6 +81,7 @@ class _DTypeDict(_DTypeDictBase, total=False):


# A protocol for anything with the dtype attribute
@runtime_checkable
class _SupportsDType(Protocol[_DType_co]):
@property
def dtype(self) -> _DType_co: ...
Expand Down
2 changes: 2 additions & 0 deletions numpy/_typing/_nested_sequence.py
Expand Up @@ -8,13 +8,15 @@
overload,
TypeVar,
Protocol,
runtime_checkable,
)

__all__ = ["_NestedSequence"]

_T_co = TypeVar("_T_co", covariant=True)


@runtime_checkable
class _NestedSequence(Protocol[_T_co]):
"""A protocol for representing nested sequences.
Expand Down
33 changes: 32 additions & 1 deletion numpy/typing/tests/test_runtime.py
Expand Up @@ -3,11 +3,19 @@
from __future__ import annotations

import sys
from typing import get_type_hints, Union, NamedTuple, get_args, get_origin
from typing import (
get_type_hints,
Union,
NamedTuple,
get_args,
get_origin,
Any,
)

import pytest
import numpy as np
import numpy.typing as npt
import numpy._typing as _npt


class TypeTup(NamedTuple):
Expand Down Expand Up @@ -80,3 +88,26 @@ def test_keys() -> None:
keys = TYPES.keys()
ref = set(npt.__all__)
assert keys == ref


PROTOCOLS: dict[str, tuple[type[Any], object]] = {
"_SupportsDType": (_npt._SupportsDType, np.int64(1)),
"_SupportsArray": (_npt._SupportsArray, np.arange(10)),
"_SupportsArrayFunc": (_npt._SupportsArrayFunc, np.arange(10)),
"_NestedSequence": (_npt._NestedSequence, [1]),
}


@pytest.mark.parametrize("cls,obj", PROTOCOLS.values(), ids=PROTOCOLS.keys())
class TestRuntimeProtocol:
def test_isinstance(self, cls: type[Any], obj: object) -> None:
assert isinstance(obj, cls)
assert not isinstance(None, cls)

def test_issubclass(self, cls: type[Any], obj: object) -> None:
if cls is _npt._SupportsDType:
pytest.xfail(
"Protocols with non-method members don't support issubclass()"
)
assert issubclass(type(obj), cls)
assert not issubclass(type(None), cls)

0 comments on commit 5490d87

Please sign in to comment.