Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add parameter strict to assert_equal #24770

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/release/upcoming_changes/24770.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
``strict`` option for `testing.assert_equal`
--------------------------------------------
The ``strict`` option is now available for `testing.assert_equal`.
Setting ``strict=True`` will disable the broadcasting behaviour for scalars
and ensure that input arrays have the same data type.
64 changes: 57 additions & 7 deletions numpy/testing/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,14 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
return '\n'.join(msg)


def assert_equal(actual, desired, err_msg='', verbose=True):
def assert_equal(actual, desired, err_msg='', verbose=True, *, strict=False):
"""
Raises an AssertionError if two objects are not equal.

Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
check that all elements of these objects are equal. An exception is raised
at the first conflicting values.

When one of `actual` and `desired` is a scalar and the other is array_like,
the function checks that each element of the array_like object is equal to
the scalar.

This function handles NaN comparisons as if NaN was a "normal" number.
That is, AssertionError is not raised if both objects have NaNs in the same
positions. This is in contrast to the IEEE standard on NaNs, which says
Expand All @@ -236,15 +232,34 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
The error message to be printed in case of failure.
verbose : bool, optional
If True, the conflicting values are appended to the error message.
strict : bool, optional
If True and either of the `actual` and `desired` arguments is an array,
raise an ``AssertionError`` when either the shape or the data type of
the arguments does not match. If neither argument is an array, this
parameter has no effect.

.. versionadded:: 2.0.0

Raises
------
AssertionError
If actual and desired are not equal.

See Also
--------
assert_allclose
assert_array_almost_equal_nulp,
assert_array_max_ulp,

Notes
-----
By default, when one of `actual` and `desired` is a scalar and the other is
an array, the function checks that each element of the array is equal to
the scalar. This behaviour can be disabled by setting ``strict==True``.

Examples
--------
>>> np.testing.assert_equal([4,5], [4,6])
>>> np.testing.assert_equal([4, 5], [4, 6])
Traceback (most recent call last):
...
AssertionError:
Expand All @@ -258,6 +273,40 @@ def assert_equal(actual, desired, err_msg='', verbose=True):

>>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])

As mentioned in the Notes section, `assert_equal` has special
handling for scalars when one of the arguments is an array.
Here, the test checks that each value in `x` is 3:

>>> x = np.full((2, 5), fill_value=3)
>>> np.testing.assert_equal(x, 3)

Use `strict` to raise an AssertionError when comparing a scalar with an
array of a different shape:

>>> np.testing.assert_equal(x, 3, strict=True)
Traceback (most recent call last):
...
AssertionError:
Arrays are not equal
<BLANKLINE>
(shapes (2, 5), () mismatch)
x: array([[3, 3, 3, 3, 3],
[3, 3, 3, 3, 3]])
y: array(3)

The `strict` parameter also ensures that the array data types match:

>>> x = np.array([2, 2, 2])
>>> y = np.array([2., 2., 2.], dtype=np.float32)
>>> np.testing.assert_equal(x, y, strict=True)
Traceback (most recent call last):
...
AssertionError:
Arrays are not equal
<BLANKLINE>
(dtypes int64, float32 mismatch)
x: array([2, 2, 2])
y: array([2., 2., 2.], dtype=float32)
"""
__tracebackhide__ = True # Hide traceback for py.test
if isinstance(desired, dict):
Expand All @@ -279,7 +328,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
from numpy.core import ndarray, isscalar, signbit
from numpy import iscomplexobj, real, imag
if isinstance(actual, ndarray) or isinstance(desired, ndarray):
return assert_array_equal(actual, desired, err_msg, verbose)
return assert_array_equal(actual, desired, err_msg, verbose,
strict=strict)
msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

# Handle complex numbers: separate into real/imag to handle
Expand Down
2 changes: 2 additions & 0 deletions numpy/testing/_private/utils.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def assert_equal(
desired: object,
err_msg: str = ...,
verbose: bool = ...,
*,
strict: bool = ...
) -> None: ...

def print_assert_equal(
Expand Down
6 changes: 3 additions & 3 deletions numpy/testing/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,22 +232,22 @@ def test_array_vs_scalar_strict(self):
b = 1.

with pytest.raises(AssertionError):
assert_array_equal(a, b, strict=True)
self._assert_func(a, b, strict=True)
Copy link
Contributor Author

@mdhaber mdhaber Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TestEqual, the class used to test assert_equal, inherits from this class. This class is TestArrayEqual, which is used to test assert_array_equal.)

Before this PR, these three tests of the strict parameter would run twice - once for TestArrayEqual, and once for TestEqual - but assert_array_equal was tested both times. Now, these will test the appropriate assertion function each time they run.

These were the only tests of the strict parameter added when the strict parameter was added to assert_array_equal in gh-21595. If they were sufficient there, hopefully enabling them for assert_equal is sufficient here.


def test_array_vs_array_strict(self):
"""Test comparing two arrays with strict option."""
a = np.array([1., 1., 1.])
b = np.array([1., 1., 1.])

assert_array_equal(a, b, strict=True)
self._assert_func(a, b, strict=True)

def test_array_vs_float_array_strict(self):
"""Test comparing two arrays with strict option."""
a = np.array([1, 1, 1])
b = np.array([1., 1., 1.])

with pytest.raises(AssertionError):
assert_array_equal(a, b, strict=True)
self._assert_func(a, b, strict=True)


class TestBuildErrorMessage:
Expand Down