Skip to content

Commit

Permalink
TST: Add test for comparison functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelcfsousa committed May 20, 2022
1 parent 2336869 commit 0bb6b36
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions numpy/core/tests/test_umath.py
Expand Up @@ -185,6 +185,82 @@ def __array_wrap__(self, arr, context):


class TestComparisons:
@pytest.mark.parametrize('dtype', np.sctypes['uint'] + np.sctypes['int'] +
np.sctypes['float'] + [np.bool_])
def test_comparison_functions(self, dtype):
# Initialize input arrays
if dtype == np.bool_:
a = np.random.choice(a=[False, True], size=1000)
b = np.random.choice(a=[False, True], size=1000)
scalar = True
else:
a = np.random.randint(low=1, high=10, size=1000).astype(dtype)
b = np.random.randint(low=1, high=10, size=1000).astype(dtype)
scalar = 5
scalar_np = np.dtype(dtype).type(scalar)
a_lst = a.tolist()
b_lst = b.tolist()

# (Binary) Comparison (x1=array, x2=array)
lt_b = np.less(a, b)
le_b = np.less_equal(a, b)
gt_b = np.greater(a, b)
ge_b = np.greater_equal(a, b)
eq_b = np.equal(a, b)
ne_b = np.not_equal(a, b)
lt_b_lst = [x < y for x, y in zip(a_lst, b_lst)]
le_b_lst = [x <= y for x, y in zip(a_lst, b_lst)]
gt_b_lst = [x > y for x, y in zip(a_lst, b_lst)]
ge_b_lst = [x >= y for x, y in zip(a_lst, b_lst)]
eq_b_lst = [x == y for x, y in zip(a_lst, b_lst)]
ne_b_lst = [x != y for x, y in zip(a_lst, b_lst)]

# (Scalar1) Comparison (x1=scalar, x2=array)
lt_s1 = np.less(scalar_np, b)
le_s1 = np.less_equal(scalar_np, b)
gt_s1 = np.greater(scalar_np, b)
ge_s1 = np.greater_equal(scalar_np, b)
eq_s1 = np.equal(scalar_np, b)
ne_s1 = np.not_equal(scalar_np, b)
lt_s1_lst = [scalar < x for x in b_lst]
le_s1_lst = [scalar <= x for x in b_lst]
gt_s1_lst = [scalar > x for x in b_lst]
ge_s1_lst = [scalar >= x for x in b_lst]
eq_s1_lst = [scalar == x for x in b_lst]
ne_s1_lst = [scalar != x for x in b_lst]

# (Scalar2) Comparison (x1=array, x2=scalar)
lt_s2 = np.less(a, scalar_np)
le_s2 = np.less_equal(a, scalar_np)
gt_s2 = np.greater(a, scalar_np)
ge_s2 = np.greater_equal(a, scalar_np)
eq_s2 = np.equal(a, scalar_np)
ne_s2 = np.not_equal(a, scalar_np)
lt_s2_lst = [x < scalar for x in a_lst]
le_s2_lst = [x <= scalar for x in a_lst]
gt_s2_lst = [x > scalar for x in a_lst]
ge_s2_lst = [x >= scalar for x in a_lst]
eq_s2_lst = [x == scalar for x in a_lst]
ne_s2_lst = [x != scalar for x in a_lst]

# Compare comparison functions (Python vs NumPy) using native Python
def compare(lt, le, gt, ge, eq, ne, lt_lst, le_lst, gt_lst, ge_lst,
eq_lst, ne_lst):
assert_(lt.tolist() == lt_lst, "Comparison function check (lt)")
assert_(le.tolist() == le_lst, "Comparison function check (le)")
assert_(gt.tolist() == gt_lst, "Comparison function check (gt)")
assert_(ge.tolist() == ge_lst, "Comparison function check (ge)")
assert_(eq.tolist() == eq_lst, "Comparison function check (eq)")
assert_(ne.tolist() == ne_lst, "Comparison function check (ne)")

# Sequence: Binary, Scalar1 and Scalar2
compare(lt_b, le_b, gt_b, ge_b, eq_b, ne_b, lt_b_lst, le_b_lst,
gt_b_lst, ge_b_lst, eq_b_lst, ne_b_lst)
compare(lt_s1, le_s1, gt_s1, ge_s1, eq_s1, ne_s1, lt_s1_lst, le_s1_lst,
gt_s1_lst, ge_s1_lst, eq_s1_lst, ne_s1_lst)
compare(lt_s2, le_s2, gt_s2, ge_s2, eq_s2, ne_s2, lt_s2_lst, le_s2_lst,
gt_s2_lst, ge_s2_lst, eq_s2_lst, ne_s2_lst)

def test_ignore_object_identity_in_equal(self):
# Check comparing identical objects whose comparison
# is not a simple boolean, e.g., arrays that are compared elementwise.
Expand Down

0 comments on commit 0bb6b36

Please sign in to comment.