diff --git a/doc/release/upcoming_changes/21595.new_feature.rst b/doc/release/upcoming_changes/21595.new_feature.rst new file mode 100644 index 000000000000..21b2a746f1b7 --- /dev/null +++ b/doc/release/upcoming_changes/21595.new_feature.rst @@ -0,0 +1,5 @@ +``strict`` option for `testing.assert_array_equal` +-------------------------------------------------- +The ``strict`` option is now available for `testing.assert_array_equal`. +Setting ``strict=True`` will disable the broadcasting behaviour for scalars and +ensure that input arrays have the same data type. diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index e4f8b98924a2..f60ca6922192 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -699,7 +699,8 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', - precision=6, equal_nan=True, equal_inf=True): + precision=6, equal_nan=True, equal_inf=True, + *, strict=False): __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_ @@ -753,11 +754,18 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): return y_id try: - cond = (x.shape == () or y.shape == ()) or x.shape == y.shape + if strict: + cond = x.shape == y.shape and x.dtype == y.dtype + else: + cond = (x.shape == () or y.shape == ()) or x.shape == y.shape if not cond: + if x.shape != y.shape: + reason = f'\n(shapes {x.shape}, {y.shape} mismatch)' + else: + reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)' msg = build_err_msg([x, y], err_msg - + f'\n(shapes {x.shape}, {y.shape} mismatch)', + + reason, verbose=verbose, header=header, names=('x', 'y'), precision=precision) raise AssertionError(msg) @@ -852,7 +860,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): raise ValueError(msg) -def assert_array_equal(x, y, err_msg='', verbose=True): +def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): """ Raises an AssertionError if two array_like objects are not equal. @@ -876,6 +884,10 @@ def assert_array_equal(x, y, err_msg='', verbose=True): The error message to be printed in case of failure. verbose : bool, optional If True, the conflicting values are appended to the error message. + strict : bool, optional + If True, raise an AssertionError when either the shape or the data + type of the array_like objects does not match. The special + handling for scalars mentioned in the Notes section is disabled. Raises ------ @@ -892,7 +904,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True): ----- When one of `x` and `y` is a scalar and the other is array_like, the function checks that each element of the array_like object is equal to - the scalar. + the scalar. This behaviour can be disabled with the `strict` parameter. Examples -------- @@ -929,10 +941,38 @@ def assert_array_equal(x, y, err_msg='', verbose=True): >>> x = np.full((2, 5), fill_value=3) >>> np.testing.assert_array_equal(x, 3) + Use `strict` to raise an AssertionError when comparing a scalar with an + array: + + >>> np.testing.assert_array_equal(x, 3, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (shapes (2, 5), () mismatch) + x: array([[3, 3, 3, 3, 3], + [3, 3, 3, 3, 3]]) + y: array(3) + + The `strict` parameter also ensures that the array data types match: + + >>> x = np.array([2, 2, 2]) + >>> y = np.array([2., 2., 2.], dtype=np.float32) + >>> np.testing.assert_array_equal(x, y, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (dtypes int64, float32 mismatch) + x: array([2, 2, 2]) + y: array([2., 2., 2.], dtype=float32) """ __tracebackhide__ = True # Hide traceback for py.test assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, - verbose=verbose, header='Arrays are not equal') + verbose=verbose, header='Arrays are not equal', + strict=strict) def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi index 0be13b7297e0..6e051e914177 100644 --- a/numpy/testing/_private/utils.pyi +++ b/numpy/testing/_private/utils.pyi @@ -200,6 +200,8 @@ def assert_array_compare( precision: SupportsIndex = ..., equal_nan: bool = ..., equal_inf: bool = ..., + *, + strict: bool = ... ) -> None: ... def assert_array_equal( @@ -207,6 +209,8 @@ def assert_array_equal( y: ArrayLike, err_msg: str = ..., verbose: bool = ..., + *, + strict: bool = ... ) -> None: ... def assert_array_almost_equal( diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 49eeecc8ee03..c82343f0c3ab 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -214,6 +214,43 @@ def test_suppress_overflow_warnings(self): np.array([1, 2, 3], np.float32), np.array([1, 1e-40, 3], np.float32)) + def test_array_vs_scalar_is_equal(self): + """Test comparing an array with a scalar when all values are equal.""" + a = np.array([1., 1., 1.]) + b = 1. + + self._test_equal(a, b) + + def test_array_vs_scalar_not_equal(self): + """Test comparing an array with a scalar when not all values equal.""" + a = np.array([1., 2., 3.]) + b = 1. + + self._test_not_equal(a, b) + + def test_array_vs_scalar_strict(self): + """Test comparing an array with a scalar with strict option.""" + a = np.array([1., 1., 1.]) + b = 1. + + with pytest.raises(AssertionError): + assert_array_equal(a, b, strict=True) + + def test_array_vs_array_strict(self): + """Test comparing two arrays with strict option.""" + a = np.array([1., 1., 1.]) + b = np.array([1., 1., 1.]) + + assert_array_equal(a, b, strict=True) + + def test_array_vs_float_array_strict(self): + """Test comparing two arrays with strict option.""" + a = np.array([1, 1, 1]) + b = np.array([1., 1., 1.]) + + with pytest.raises(AssertionError): + assert_array_equal(a, b, strict=True) + class TestBuildErrorMessage: