Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add strict parameter to assert_array_equal. Fixes #9542 #21595

Merged
merged 12 commits into from Jun 24, 2022
5 changes: 5 additions & 0 deletions doc/release/upcoming_changes/21595.new_feature.rst
@@ -0,0 +1,5 @@
``strict`` option for `testing.assert_array_equal`
----------------------------------------------------
WarrenWeckesser marked this conversation as resolved.
Show resolved Hide resolved
The ``strict`` option is now available for `testing.assert_array_equal`.
Setting ``strict=True`` will disable the broadcasting behaviour for scalars and
ensure that input arrays have the same data type.
45 changes: 40 additions & 5 deletions numpy/testing/_private/utils.py
Expand Up @@ -699,7 +699,8 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):


def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
precision=6, equal_nan=True, equal_inf=True):
precision=6, equal_nan=True, equal_inf=True,
*, strict=False):
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_

Expand Down Expand Up @@ -753,7 +754,10 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
return y_id

try:
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
if strict:
cond = x.shape == y.shape and x.dtype == y.dtype
else:
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
if not cond:
msg = build_err_msg([x, y],
err_msg
Expand Down Expand Up @@ -852,7 +856,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
raise ValueError(msg)


def assert_array_equal(x, y, err_msg='', verbose=True):
def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False):
"""
Raises an AssertionError if two array_like objects are not equal.

Expand All @@ -876,6 +880,10 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
The error message to be printed in case of failure.
verbose : bool, optional
If True, the conflicting values are appended to the error message.
strict : bool, optional
If True, raise an assertion when either the shape or the data
WarrenWeckesser marked this conversation as resolved.
Show resolved Hide resolved
type of the array_like objects does not match. The special
handling for scalars mentioned in the Notes section is disabled.

Raises
------
Expand All @@ -892,7 +900,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
-----
When one of `x` and `y` is a scalar and the other is array_like, the
function checks that each element of the array_like object is equal to
the scalar.
the scalar. This behaviour can be disabled with the `strict` parameter.

Examples
--------
Expand Down Expand Up @@ -929,10 +937,37 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
>>> x = np.full((2, 5), fill_value=3)
>>> np.testing.assert_array_equal(x, 3)

Use `strict` to raise an assertion when comparing a scalar with an array:
WarrenWeckesser marked this conversation as resolved.
Show resolved Hide resolved

>>> np.testing.assert_array_equal(x, 3, strict=True)
Traceback (most recent call last):
...
AssertionError:
Arrays are not equal
<BLANKLINE>
(shapes (2, 5), () mismatch)
x: array([[3, 3, 3, 3, 3],
[3, 3, 3, 3, 3]])
y: array(3)

The `strict` parameter also ensures that the array data types match:

>>> x = np.array([2, 2, 2])
>>> y = np.array([2., 2., 2.], dtype=np.float32)
>>> np.testing.assert_array_equal(x, y, strict=True)
Traceback (most recent call last):
...
AssertionError:
Arrays are not equal
<BLANKLINE>
(shapes (3,), (3,) mismatch)
x: array([2, 2, 2])
y: array([2., 2., 2.], dtype=float32)
"""
__tracebackhide__ = True # Hide traceback for py.test
assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
verbose=verbose, header='Arrays are not equal')
verbose=verbose, header='Arrays are not equal',
strict=strict)


def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
Expand Down
4 changes: 4 additions & 0 deletions numpy/testing/_private/utils.pyi
Expand Up @@ -200,13 +200,17 @@ def assert_array_compare(
precision: SupportsIndex = ...,
equal_nan: bool = ...,
equal_inf: bool = ...,
*,
strict: bool = ...
) -> None: ...

def assert_array_equal(
x: ArrayLike,
y: ArrayLike,
err_msg: str = ...,
verbose: bool = ...,
*,
strict: bool = ...
BvB93 marked this conversation as resolved.
Show resolved Hide resolved
) -> None: ...

def assert_array_almost_equal(
Expand Down
37 changes: 37 additions & 0 deletions numpy/testing/tests/test_utils.py
Expand Up @@ -214,6 +214,43 @@ def test_suppress_overflow_warnings(self):
np.array([1, 2, 3], np.float32),
np.array([1, 1e-40, 3], np.float32))

def test_array_vs_scalar_is_equal(self):
"""Test comparing an array with a scalar when all values are equal."""
a = np.array([1., 1., 1.])
b = 1.

self._test_equal(a, b)

def test_array_vs_scalar_not_equal(self):
"""Test comparing an array with a scalar when not all values equal."""
a = np.array([1., 2., 3.])
b = 1.

self._test_not_equal(a, b)

def test_array_vs_scalar_strict(self):
"""Test comparing an array with a scalar with strict option."""
a = np.array([1., 1., 1.])
b = 1.

with pytest.raises(AssertionError):
assert_array_equal(a, b, strict=True)

def test_array_vs_array_strict(self):
"""Test comparing two arrays with strict option."""
a = np.array([1., 1., 1.])
b = np.array([1., 1., 1.])

assert_array_equal(a, b, strict=True)

def test_array_vs_float_array_strict(self):
"""Test comparing two arrays with strict option."""
a = np.array([1, 1, 1])
b = np.array([1., 1., 1.])

with pytest.raises(AssertionError):
assert_array_equal(a, b, strict=True)


class TestBuildErrorMessage:

Expand Down