Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Allow integer inputs for pow-related functions in array_api #20807

Merged
merged 1 commit into from Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 6 additions & 8 deletions numpy/array_api/_array_object.py
Expand Up @@ -655,15 +655,13 @@ def __pos__(self: Array, /) -> Array:
res = self._array.__pos__()
return self.__class__._new(res)

# PEP 484 requires int to be a subtype of float, but __pow__ should not
# accept int.
def __pow__(self: Array, other: Union[float, Array], /) -> Array:
def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __pow__.
"""
from ._elementwise_functions import pow

other = self._check_allowed_dtypes(other, "floating-point", "__pow__")
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
Expand Down Expand Up @@ -913,23 +911,23 @@ def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array:
res = self._array.__ror__(other._array)
return self.__class__._new(res)

def __ipow__(self: Array, other: Union[float, Array], /) -> Array:
def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ipow__.
"""
other = self._check_allowed_dtypes(other, "floating-point", "__ipow__")
other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
if other is NotImplemented:
return other
self._array.__ipow__(other._array)
return self

def __rpow__(self: Array, other: Union[float, Array], /) -> Array:
def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __rpow__.
"""
from ._elementwise_functions import pow

other = self._check_allowed_dtypes(other, "floating-point", "__rpow__")
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow the spec type promotion rules
Expand Down
4 changes: 2 additions & 2 deletions numpy/array_api/_elementwise_functions.py
Expand Up @@ -591,8 +591,8 @@ def pow(x1: Array, x2: Array, /) -> Array:

See its docstring for more information.
"""
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
raise TypeError("Only floating-point dtypes are allowed in pow")
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in pow")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
Expand Down
2 changes: 1 addition & 1 deletion numpy/array_api/tests/test_array_object.py
Expand Up @@ -98,7 +98,7 @@ def test_operators():
"__mul__": "numeric",
"__ne__": "all",
"__or__": "integer_or_boolean",
"__pow__": "floating",
"__pow__": "numeric",
"__rshift__": "integer",
"__sub__": "numeric",
"__truediv__": "floating",
Expand Down
2 changes: 1 addition & 1 deletion numpy/array_api/tests/test_elementwise_functions.py
Expand Up @@ -66,7 +66,7 @@ def test_function_types():
"negative": "numeric",
"not_equal": "all",
"positive": "numeric",
"pow": "floating-point",
"pow": "numeric",
"remainder": "numeric",
"round": "numeric",
"sign": "numeric",
Expand Down