From a1a0c10eba3825ef5eef02f769797e7ebeeded1a Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Wed, 25 May 2022 11:51:31 +0100 Subject: [PATCH 01/12] Add strict parameter to assert_array_equal. Fixes #9542 --- numpy/testing/_private/utils.py | 23 ++++++++++++++++++----- numpy/testing/tests/test_utils.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 4a8f42e06ef5..99a8f2135027 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -699,7 +699,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', - precision=6, equal_nan=True, equal_inf=True): + precision=6, equal_nan=True, equal_inf=True, strict=False): __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_ @@ -753,7 +753,10 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): return y_id try: - cond = (x.shape == () or y.shape == ()) or x.shape == y.shape + if strict: + cond = x.shape == y.shape + else: + cond = (x.shape == () or y.shape == ()) or x.shape == y.shape if not cond: msg = build_err_msg([x, y], err_msg @@ -852,7 +855,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): raise ValueError(msg) -def assert_array_equal(x, y, err_msg='', verbose=True): +def assert_array_equal(x, y, err_msg='', verbose=True, strict=False): """ Raises an AssertionError if two array_like objects are not equal. @@ -876,6 +879,8 @@ def assert_array_equal(x, y, err_msg='', verbose=True): The error message to be printed in case of failure. verbose : bool, optional If True, the conflicting values are appended to the error message. + strict : bool, optional + If True, raise an assertion when one of the array_like objects is a scalar. Raises ------ @@ -892,7 +897,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True): ----- When one of `x` and `y` is a scalar and the other is array_like, the function checks that each element of the array_like object is equal to - the scalar. + the scalar. This behaviour can be disabled with the `strict` parameter. Examples -------- @@ -929,10 +934,18 @@ def assert_array_equal(x, y, err_msg='', verbose=True): >>> x = np.full((2, 5), fill_value=3) >>> np.testing.assert_array_equal(x, 3) + Use `strict` to raise an assertion when comparing a scalar with an array: + + >>> np.testing.assert_array_equal(x, 3, strict=True) + + AssertionError: + Arrays are not equal + (shapes (2, 5), () mismatch) + """ __tracebackhide__ = True # Hide traceback for py.test assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, - verbose=verbose, header='Arrays are not equal') + verbose=verbose, header='Arrays are not equal', strict=strict) def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 49eeecc8ee03..9b979de1bcd9 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -214,6 +214,35 @@ def test_suppress_overflow_warnings(self): np.array([1, 2, 3], np.float32), np.array([1, 1e-40, 3], np.float32)) + def test_array_vs_scalar_is_equal(self): + """Test comparing an array with a scalar when all values are equal.""" + a = np.array([1., 1., 1.]) + b = 1. + + self._test_equal(a, b) + + def test_array_vs_scalar_not_equal(self): + """Test comparing an array with a scalar when not all values are equal.""" + a = np.array([1., 2., 3.]) + b = 1. + + self._test_not_equal(a, b) + + def test_array_vs_scalar_strict(self): + """Test comparing an array with a scalar with strict option.""" + a = np.array([1., 1., 1.]) + b = 1. + + with pytest.raises(AssertionError): + self._assert_func(a, b, strict=True) + + def test_array_vs_array_strict(self): + """Test comparing two arrays with strict option.""" + a = np.array([1., 1., 1.]) + b = np.array([1., 1., 1.]) + + self._assert_func(a, b, strict=True) + class TestBuildErrorMessage: From ac1869d68cb5a23479515a45593efe489a3e6582 Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Wed, 25 May 2022 12:00:12 +0100 Subject: [PATCH 02/12] Lint fix --- numpy/testing/tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 9b979de1bcd9..79e58c0b50c2 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -222,7 +222,7 @@ def test_array_vs_scalar_is_equal(self): self._test_equal(a, b) def test_array_vs_scalar_not_equal(self): - """Test comparing an array with a scalar when not all values are equal.""" + """Test comparing an array with a scalar when not all values equal.""" a = np.array([1., 2., 3.]) b = 1. From 5eef7b479e31d35cb5c83af96ce6d152ed9fb41e Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Wed, 25 May 2022 12:15:14 +0100 Subject: [PATCH 03/12] Lint fix --- numpy/testing/_private/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 99a8f2135027..aaaedec4f5ea 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -699,7 +699,8 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', - precision=6, equal_nan=True, equal_inf=True, strict=False): + precision=6, equal_nan=True, equal_inf=True, + strict=False): __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_ @@ -880,7 +881,8 @@ def assert_array_equal(x, y, err_msg='', verbose=True, strict=False): verbose : bool, optional If True, the conflicting values are appended to the error message. strict : bool, optional - If True, raise an assertion when one of the array_like objects is a scalar. + If True, raise an assertion when one of the array_like objects is a + scalar. Raises ------ @@ -945,7 +947,8 @@ def assert_array_equal(x, y, err_msg='', verbose=True, strict=False): """ __tracebackhide__ = True # Hide traceback for py.test assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, - verbose=verbose, header='Arrays are not equal', strict=strict) + verbose=verbose, header='Arrays are not equal', + strict=strict) def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): From 982ca996acefc3f13b654bb724f7bd25027f7b0e Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Wed, 25 May 2022 13:49:17 +0100 Subject: [PATCH 04/12] Call assert_array_equal directly --- numpy/testing/tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 79e58c0b50c2..47c12ce81e41 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -234,14 +234,14 @@ def test_array_vs_scalar_strict(self): b = 1. with pytest.raises(AssertionError): - self._assert_func(a, b, strict=True) + assert_array_equal(a, b, strict=True) def test_array_vs_array_strict(self): """Test comparing two arrays with strict option.""" a = np.array([1., 1., 1.]) b = np.array([1., 1., 1.]) - self._assert_func(a, b, strict=True) + assert_array_equal(a, b, strict=True) class TestBuildErrorMessage: From 6bc1abf0eb069136a59de337323df0736687efe9 Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Wed, 25 May 2022 17:17:10 +0100 Subject: [PATCH 05/12] Update doc example so refguide check passes --- numpy/testing/_private/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index aaaedec4f5ea..ae285bbfd9e1 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -939,10 +939,15 @@ def assert_array_equal(x, y, err_msg='', verbose=True, strict=False): Use `strict` to raise an assertion when comparing a scalar with an array: >>> np.testing.assert_array_equal(x, 3, strict=True) - + Traceback (most recent call last): + ... AssertionError: Arrays are not equal + (shapes (2, 5), () mismatch) + x: array([[3, 3, 3, 3, 3], + [3, 3, 3, 3, 3]]) + y: array(3) """ __tracebackhide__ = True # Hide traceback for py.test From 84dd0ac4afdded1cd1b177b209dcef5c9218bea2 Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Fri, 27 May 2022 20:10:54 +0100 Subject: [PATCH 06/12] Apply suggestions from code review Make new arguments keyword-only Co-authored-by: Bas van Beek <43369155+BvB93@users.noreply.github.com> --- numpy/testing/_private/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index ae285bbfd9e1..ba3aa7dfd672 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -700,7 +700,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', precision=6, equal_nan=True, equal_inf=True, - strict=False): + *, strict=False): __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_ @@ -856,7 +856,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): raise ValueError(msg) -def assert_array_equal(x, y, err_msg='', verbose=True, strict=False): +def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): """ Raises an AssertionError if two array_like objects are not equal. From a9b21fb6b5bebbea1b446a7c926e76dace09bb59 Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Fri, 27 May 2022 20:54:08 +0100 Subject: [PATCH 07/12] Update utils.pyi with new parameter --- numpy/testing/_private/utils.pyi | 2 ++ 1 file changed, 2 insertions(+) diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi index 0be13b7297e0..ca26ef8a148b 100644 --- a/numpy/testing/_private/utils.pyi +++ b/numpy/testing/_private/utils.pyi @@ -207,6 +207,8 @@ def assert_array_equal( y: ArrayLike, err_msg: str = ..., verbose: bool = ..., + *, + strict: bool = ... ) -> None: ... def assert_array_almost_equal( From 6aaffb6a9340326602732f076c802050c33a2b5e Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Fri, 27 May 2022 20:54:30 +0100 Subject: [PATCH 08/12] Add data type check to strict mode --- numpy/testing/_private/utils.py | 17 +++++++++++++++-- numpy/testing/tests/test_utils.py | 8 ++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index ba3aa7dfd672..59e108116271 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -755,7 +755,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): try: if strict: - cond = x.shape == y.shape + cond = x.shape == y.shape and x.dtype == y.dtype else: cond = (x.shape == () or y.shape == ()) or x.shape == y.shape if not cond: @@ -882,7 +882,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): If True, the conflicting values are appended to the error message. strict : bool, optional If True, raise an assertion when one of the array_like objects is a - scalar. + scalar or if `x` and `y` have a different data type. Raises ------ @@ -949,6 +949,19 @@ def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): [3, 3, 3, 3, 3]]) y: array(3) + The `strict` parameter also ensures that the array data types match: + + >>> x = np.array([2, 2, 2]) + >>> y = np.array([2., 2., 2.], dtype=np.float32) + >>> np.testing.assert_array_equal(x, y, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (shapes (3,), (3,) mismatch) + x: array([2, 2, 2]) + y: array([2., 2., 2.], dtype=float32) """ __tracebackhide__ = True # Hide traceback for py.test assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 47c12ce81e41..c82343f0c3ab 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -243,6 +243,14 @@ def test_array_vs_array_strict(self): assert_array_equal(a, b, strict=True) + def test_array_vs_float_array_strict(self): + """Test comparing two arrays with strict option.""" + a = np.array([1, 1, 1]) + b = np.array([1., 1., 1.]) + + with pytest.raises(AssertionError): + assert_array_equal(a, b, strict=True) + class TestBuildErrorMessage: From 018537929560d1f77ce9d3cebbd7d756fba4c43f Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Sat, 28 May 2022 10:18:46 +0100 Subject: [PATCH 09/12] Responding to review --- numpy/testing/_private/utils.py | 5 +++-- numpy/testing/_private/utils.pyi | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 59e108116271..3820365946fb 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -881,8 +881,9 @@ def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): verbose : bool, optional If True, the conflicting values are appended to the error message. strict : bool, optional - If True, raise an assertion when one of the array_like objects is a - scalar or if `x` and `y` have a different data type. + If True, raise an assertion when either the shape or the data + type of the array_like objects does not match. The special + handling for scalars mentioned in the Notes section is disabled. Raises ------ diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi index ca26ef8a148b..6e051e914177 100644 --- a/numpy/testing/_private/utils.pyi +++ b/numpy/testing/_private/utils.pyi @@ -200,6 +200,8 @@ def assert_array_compare( precision: SupportsIndex = ..., equal_nan: bool = ..., equal_inf: bool = ..., + *, + strict: bool = ... ) -> None: ... def assert_array_equal( From 5601d964bf484afca9c62b60604764dcd8861d31 Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Sun, 29 May 2022 15:08:34 +0100 Subject: [PATCH 10/12] Add release note --- doc/release/upcoming_changes/21595.new_feature.rst | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 doc/release/upcoming_changes/21595.new_feature.rst diff --git a/doc/release/upcoming_changes/21595.new_feature.rst b/doc/release/upcoming_changes/21595.new_feature.rst new file mode 100644 index 000000000000..c6164969dd3b --- /dev/null +++ b/doc/release/upcoming_changes/21595.new_feature.rst @@ -0,0 +1,5 @@ +``strict`` option for `testing.assert_array_equal` +---------------------------------------------------- +The ``strict`` option is now available for `testing.assert_array_equal`. +Setting ``strict=True`` will disable the broadcasting behaviour for scalars and +ensure that input arrays have the same data type. From ebf9ddf517550cef3ea9fae62ebaaac6488397a7 Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Tue, 21 Jun 2022 19:51:55 +0100 Subject: [PATCH 11/12] Responding to review --- doc/release/upcoming_changes/21595.new_feature.rst | 2 +- numpy/testing/_private/utils.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/doc/release/upcoming_changes/21595.new_feature.rst b/doc/release/upcoming_changes/21595.new_feature.rst index c6164969dd3b..21b2a746f1b7 100644 --- a/doc/release/upcoming_changes/21595.new_feature.rst +++ b/doc/release/upcoming_changes/21595.new_feature.rst @@ -1,5 +1,5 @@ ``strict`` option for `testing.assert_array_equal` ----------------------------------------------------- +-------------------------------------------------- The ``strict`` option is now available for `testing.assert_array_equal`. Setting ``strict=True`` will disable the broadcasting behaviour for scalars and ensure that input arrays have the same data type. diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 3820365946fb..2fec97d56f28 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -759,9 +759,13 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): else: cond = (x.shape == () or y.shape == ()) or x.shape == y.shape if not cond: + if x.shape != y.shape: + reason = f'\n(shapes {x.shape}, {y.shape} mismatch)' + else: + reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)' msg = build_err_msg([x, y], err_msg - + f'\n(shapes {x.shape}, {y.shape} mismatch)', + + reason, verbose=verbose, header=header, names=('x', 'y'), precision=precision) raise AssertionError(msg) @@ -881,7 +885,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): verbose : bool, optional If True, the conflicting values are appended to the error message. strict : bool, optional - If True, raise an assertion when either the shape or the data + If True, raise an AssertionError when either the shape or the data type of the array_like objects does not match. The special handling for scalars mentioned in the Notes section is disabled. @@ -937,7 +941,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): >>> x = np.full((2, 5), fill_value=3) >>> np.testing.assert_array_equal(x, 3) - Use `strict` to raise an assertion when comparing a scalar with an array: + Use `strict` to raise an AssertionError when comparing a scalar with an array: >>> np.testing.assert_array_equal(x, 3, strict=True) Traceback (most recent call last): @@ -960,7 +964,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): AssertionError: Arrays are not equal - (shapes (3,), (3,) mismatch) + (dtypes int64, float32 mismatch) x: array([2, 2, 2]) y: array([2., 2., 2.], dtype=float32) """ From 61ecc910474a3d62130a66872cb1005f45e5a716 Mon Sep 17 00:00:00 2001 From: Jon Morris Date: Tue, 21 Jun 2022 20:08:48 +0100 Subject: [PATCH 12/12] Lint fix --- numpy/testing/_private/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 2fec97d56f28..c8fa5c2e0484 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -941,7 +941,8 @@ def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): >>> x = np.full((2, 5), fill_value=3) >>> np.testing.assert_array_equal(x, 3) - Use `strict` to raise an AssertionError when comparing a scalar with an array: + Use `strict` to raise an AssertionError when comparing a scalar with an + array: >>> np.testing.assert_array_equal(x, 3, strict=True) Traceback (most recent call last):