diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index e9b024891884..cd413c58f685 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -3653,11 +3653,12 @@ class RankWarning(UserWarning): ... class TooHardError(RuntimeError): ... class AxisError(ValueError, IndexError): - axis: int + axis: None | int ndim: None | int - def __init__( - self, axis: int, ndim: None | int = ..., msg_prefix: None | str = ... - ) -> None: ... + @overload + def __init__(self, axis: str, ndim: None = ..., msg_prefix: None = ...) -> None: ... + @overload + def __init__(self, axis: int, ndim: int, msg_prefix: None | str = ...) -> None: ... _CallType = TypeVar("_CallType", bound=Union[_ErrFunc, _SupportsWrite]) diff --git a/numpy/core/_exceptions.py b/numpy/core/_exceptions.py index 98c00737efa6..18e69c829b37 100644 --- a/numpy/core/_exceptions.py +++ b/numpy/core/_exceptions.py @@ -133,8 +133,9 @@ class AxisError(ValueError, IndexError): Parameters ---------- - axis : int - The out of bounds axis. + axis : int or str + The out of bounds axis or a custom exception message. + If an axis is provided, then `ndim` should be specified as well. ndim : int, optional The number of array dimensions. msg_prefix : str, optional @@ -142,11 +143,12 @@ class AxisError(ValueError, IndexError): Attributes ---------- - axis : int - The out of bounds axis. + axis : int, optional + The out of bounds axis or ``None`` if a custom exception + message was provided. ndim : int, optional - The number of array dimensions. - Defaults to ``None`` if unspecified. + The number of array dimensions or ``None`` if a custom exception + message was provided. Examples -------- @@ -156,17 +158,27 @@ class AxisError(ValueError, IndexError): ... numpy.AxisError: axis 1 is out of bounds for array of dimension 1 + The class constructor generally takes the axis and arrays' + dimensionality as arguments: + + >>> np.AxisError(2, 1, prefix='error') + numpy.AxisError('error: axis 2 is out of bounds for array of dimension 1') + + Alternativelly, a custom exception message can be passed: + + >>> np.AxisError('Custom error message') + numpy.AxisError('Custom error message') + """ __slots__ = ("axis", "ndim") def __init__(self, axis, ndim=None, msg_prefix=None): - self.axis = axis - self.ndim = ndim - # single-argument form just delegates to base class if ndim is None and msg_prefix is None: msg = axis + self.axis = None + self.ndim = None # do the string formatting here, to save work in the C code else: @@ -174,6 +186,8 @@ def __init__(self, axis, ndim=None, msg_prefix=None): .format(axis, ndim)) if msg_prefix is not None: msg = "{}: {}".format(msg_prefix, msg) + self.axis = axis + self.ndim = ndim super().__init__(msg) diff --git a/numpy/core/tests/test__exceptions.py b/numpy/core/tests/test__exceptions.py index 49732204da6c..43fad42812ff 100644 --- a/numpy/core/tests/test__exceptions.py +++ b/numpy/core/tests/test__exceptions.py @@ -62,15 +62,18 @@ def test_pickling(self): @pytest.mark.parametrize("args", [ (2, 1, None), (2, 1, "test_prefix"), - (2, None, None), - (2, None, "test_prefix") + ("test message",), ]) class TestAxisError: def test_attr(self, args): """Validate attribute types.""" exc = np.AxisError(*args) - assert exc.axis == args[0] - assert exc.ndim == args[1] + if len(args) == 1: + assert exc.axis is None + assert exc.ndim is None + else: + assert exc.axis == axis + assert exc.ndim == ndim def test_pickling(self, args): """Test that `AxisError` can be pickled."""