From 2fbc8fa2424d0c2e4403f259732fbb674db0e229 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Sun, 2 Jan 2022 14:57:30 +0100 Subject: [PATCH] TYP,MAINT: Allow `ndindex` to accept integer tuples --- numpy/__init__.pyi | 3 +++ numpy/typing/tests/data/fail/index_tricks.pyi | 1 + numpy/typing/tests/data/reveal/index_tricks.pyi | 2 ++ 3 files changed, 6 insertions(+) diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index eb1e81c6ac66..9f3e42a69957 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -3321,6 +3321,9 @@ class ndenumerate(Generic[_ScalarType]): def __iter__(self: _T) -> _T: ... class ndindex: + @overload + def __init__(self, shape: tuple[SupportsIndex, ...], /) -> None: ... + @overload def __init__(self, *shape: SupportsIndex) -> None: ... def __iter__(self: _T) -> _T: ... def __next__(self) -> _Shape: ... diff --git a/numpy/typing/tests/data/fail/index_tricks.pyi b/numpy/typing/tests/data/fail/index_tricks.pyi index c508bf3aeae6..565e81a9ab25 100644 --- a/numpy/typing/tests/data/fail/index_tricks.pyi +++ b/numpy/typing/tests/data/fail/index_tricks.pyi @@ -4,6 +4,7 @@ import numpy as np AR_LIKE_i: List[int] AR_LIKE_f: List[float] +np.ndindex([1, 2, 3]) # E: No overload variant np.unravel_index(AR_LIKE_f, (1, 2, 3)) # E: incompatible type np.ravel_multi_index(AR_LIKE_i, (1, 2, 3), mode="bob") # E: No overload variant np.mgrid[1] # E: Invalid index type diff --git a/numpy/typing/tests/data/reveal/index_tricks.pyi b/numpy/typing/tests/data/reveal/index_tricks.pyi index cee4d8c3e7e6..55c033fe011f 100644 --- a/numpy/typing/tests/data/reveal/index_tricks.pyi +++ b/numpy/typing/tests/data/reveal/index_tricks.pyi @@ -24,6 +24,8 @@ reveal_type(iter(np.ndenumerate(AR_i8))) # E: Iterator[Tuple[builtins.tuple[bui reveal_type(iter(np.ndenumerate(AR_LIKE_f))) # E: Iterator[Tuple[builtins.tuple[builtins.int], {double}]] reveal_type(iter(np.ndenumerate(AR_LIKE_U))) # E: Iterator[Tuple[builtins.tuple[builtins.int], str_]] +reveal_type(np.ndindex(1, 2, 3)) # E: numpy.ndindex +reveal_type(np.ndindex((1, 2, 3))) # E: numpy.ndindex reveal_type(iter(np.ndindex(1, 2, 3))) # E: Iterator[builtins.tuple[builtins.int]] reveal_type(next(np.ndindex(1, 2, 3))) # E: builtins.tuple[builtins.int]