Skip to content

Commit

Permalink
ENH: Add strict parameter to assert_array_equal. (#21595)
Browse files Browse the repository at this point in the history
Fixes #9542

Co-authored-by: Bas van Beek <43369155+BvB93@users.noreply.github.com>
  • Loading branch information
jontwo and BvB93 committed Jun 24, 2022
1 parent 019c8c9 commit cafec60
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 6 deletions.
5 changes: 5 additions & 0 deletions doc/release/upcoming_changes/21595.new_feature.rst
@@ -0,0 +1,5 @@
``strict`` option for `testing.assert_array_equal`
--------------------------------------------------
The ``strict`` option is now available for `testing.assert_array_equal`.
Setting ``strict=True`` will disable the broadcasting behaviour for scalars and
ensure that input arrays have the same data type.
52 changes: 46 additions & 6 deletions numpy/testing/_private/utils.py
Expand Up @@ -699,7 +699,8 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):


def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
precision=6, equal_nan=True, equal_inf=True):
precision=6, equal_nan=True, equal_inf=True,
*, strict=False):
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_

Expand Down Expand Up @@ -753,11 +754,18 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
return y_id

try:
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
if strict:
cond = x.shape == y.shape and x.dtype == y.dtype
else:
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
if not cond:
if x.shape != y.shape:
reason = f'\n(shapes {x.shape}, {y.shape} mismatch)'
else:
reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)'
msg = build_err_msg([x, y],
err_msg
+ f'\n(shapes {x.shape}, {y.shape} mismatch)',
+ reason,
verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
Expand Down Expand Up @@ -852,7 +860,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
raise ValueError(msg)


def assert_array_equal(x, y, err_msg='', verbose=True):
def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False):
"""
Raises an AssertionError if two array_like objects are not equal.
Expand All @@ -876,6 +884,10 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
The error message to be printed in case of failure.
verbose : bool, optional
If True, the conflicting values are appended to the error message.
strict : bool, optional
If True, raise an AssertionError when either the shape or the data
type of the array_like objects does not match. The special
handling for scalars mentioned in the Notes section is disabled.
Raises
------
Expand All @@ -892,7 +904,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
-----
When one of `x` and `y` is a scalar and the other is array_like, the
function checks that each element of the array_like object is equal to
the scalar.
the scalar. This behaviour can be disabled with the `strict` parameter.
Examples
--------
Expand Down Expand Up @@ -929,10 +941,38 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
>>> x = np.full((2, 5), fill_value=3)
>>> np.testing.assert_array_equal(x, 3)
Use `strict` to raise an AssertionError when comparing a scalar with an
array:
>>> np.testing.assert_array_equal(x, 3, strict=True)
Traceback (most recent call last):
...
AssertionError:
Arrays are not equal
<BLANKLINE>
(shapes (2, 5), () mismatch)
x: array([[3, 3, 3, 3, 3],
[3, 3, 3, 3, 3]])
y: array(3)
The `strict` parameter also ensures that the array data types match:
>>> x = np.array([2, 2, 2])
>>> y = np.array([2., 2., 2.], dtype=np.float32)
>>> np.testing.assert_array_equal(x, y, strict=True)
Traceback (most recent call last):
...
AssertionError:
Arrays are not equal
<BLANKLINE>
(dtypes int64, float32 mismatch)
x: array([2, 2, 2])
y: array([2., 2., 2.], dtype=float32)
"""
__tracebackhide__ = True # Hide traceback for py.test
assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
verbose=verbose, header='Arrays are not equal')
verbose=verbose, header='Arrays are not equal',
strict=strict)


def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
Expand Down
4 changes: 4 additions & 0 deletions numpy/testing/_private/utils.pyi
Expand Up @@ -200,13 +200,17 @@ def assert_array_compare(
precision: SupportsIndex = ...,
equal_nan: bool = ...,
equal_inf: bool = ...,
*,
strict: bool = ...
) -> None: ...

def assert_array_equal(
x: ArrayLike,
y: ArrayLike,
err_msg: str = ...,
verbose: bool = ...,
*,
strict: bool = ...
) -> None: ...

def assert_array_almost_equal(
Expand Down
37 changes: 37 additions & 0 deletions numpy/testing/tests/test_utils.py
Expand Up @@ -214,6 +214,43 @@ def test_suppress_overflow_warnings(self):
np.array([1, 2, 3], np.float32),
np.array([1, 1e-40, 3], np.float32))

def test_array_vs_scalar_is_equal(self):
"""Test comparing an array with a scalar when all values are equal."""
a = np.array([1., 1., 1.])
b = 1.

self._test_equal(a, b)

def test_array_vs_scalar_not_equal(self):
"""Test comparing an array with a scalar when not all values equal."""
a = np.array([1., 2., 3.])
b = 1.

self._test_not_equal(a, b)

def test_array_vs_scalar_strict(self):
"""Test comparing an array with a scalar with strict option."""
a = np.array([1., 1., 1.])
b = 1.

with pytest.raises(AssertionError):
assert_array_equal(a, b, strict=True)

def test_array_vs_array_strict(self):
"""Test comparing two arrays with strict option."""
a = np.array([1., 1., 1.])
b = np.array([1., 1., 1.])

assert_array_equal(a, b, strict=True)

def test_array_vs_float_array_strict(self):
"""Test comparing two arrays with strict option."""
a = np.array([1, 1, 1])
b = np.array([1., 1., 1.])

with pytest.raises(AssertionError):
assert_array_equal(a, b, strict=True)


class TestBuildErrorMessage:

Expand Down

0 comments on commit cafec60

Please sign in to comment.