diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 391e9f575e60..00ffe7162158 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -655,15 +655,13 @@ def __pos__(self: Array, /) -> Array: res = self._array.__pos__() return self.__class__._new(res) - # PEP 484 requires int to be a subtype of float, but __pow__ should not - # accept int. - def __pow__(self: Array, other: Union[float, Array], /) -> Array: + def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __pow__. """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, "floating-point", "__pow__") + other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d @@ -913,23 +911,23 @@ def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__ror__(other._array) return self.__class__._new(res) - def __ipow__(self: Array, other: Union[float, Array], /) -> Array: + def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ipow__. """ - other = self._check_allowed_dtypes(other, "floating-point", "__ipow__") + other = self._check_allowed_dtypes(other, "numeric", "__ipow__") if other is NotImplemented: return other self._array.__ipow__(other._array) return self - def __rpow__(self: Array, other: Union[float, Array], /) -> Array: + def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __rpow__. """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, "floating-point", "__rpow__") + other = self._check_allowed_dtypes(other, "numeric", "__rpow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow the spec type promotion rules diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py index 4408fe833b4c..c758a09447a0 100644 --- a/numpy/array_api/_elementwise_functions.py +++ b/numpy/array_api/_elementwise_functions.py @@ -591,8 +591,8 @@ def pow(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in pow") + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in pow") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index b980bacca493..1fe1dfddf80c 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -98,7 +98,7 @@ def test_operators(): "__mul__": "numeric", "__ne__": "all", "__or__": "integer_or_boolean", - "__pow__": "floating", + "__pow__": "numeric", "__rshift__": "integer", "__sub__": "numeric", "__truediv__": "floating", diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index a9274aec9278..b2fb44e766f8 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -66,7 +66,7 @@ def test_function_types(): "negative": "numeric", "not_equal": "all", "positive": "numeric", - "pow": "floating-point", + "pow": "numeric", "remainder": "numeric", "round": "numeric", "sign": "numeric",