Skip to content

Commit

Permalink
TST: Add tests for the new nan<x> function parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Bas van Beek committed Oct 4, 2021
1 parent c7ca470 commit 986e22a
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions numpy/lib/tests/test_nanfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,46 @@ def test_object_array(self):
assert_(len(w) == 1, 'no warning raised')
assert_(issubclass(w[0].category, RuntimeWarning))

@pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
def test_initial(self, dtype):
class MyNDArray(np.ndarray):
pass

ar = np.arange(9).astype(dtype)
ar[:5] = np.nan

for f in self.nanfuncs:
initial = 100 if f is np.nanmax else 0

ret1 = f(ar, initial=initial)
assert ret1.dtype == dtype
assert ret1 == initial

ret2 = f(ar.view(MyNDArray), initial=initial)
assert ret2.dtype == dtype
assert ret2 == initial

@pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
def test_where(self, dtype):
class MyNDArray(np.ndarray):
pass

ar = np.arange(9).reshape(3, 3).astype(dtype)
ar[0, :] = np.nan
where = np.ones_like(ar, dtype=np.bool_)
where[:, 0] = False

for f in self.nanfuncs:
reference = 4 if f is np.nanmin else 8

ret1 = f(ar, where=where, initial=5)
assert ret1.dtype == dtype
assert ret1 == reference

ret2 = f(ar.view(MyNDArray), where=where, initial=5)
assert ret2.dtype == dtype
assert ret2 == reference


class TestNanFunctions_ArgminArgmax:

Expand Down Expand Up @@ -288,6 +328,30 @@ class MyNDArray(np.ndarray):
res = f(mine)
assert_(res.shape == ())

@pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
def test_keepdims(self, dtype):
ar = np.arange(9).astype(dtype)
ar[:5] = np.nan

for f in self.nanfuncs:
reference = 5 if f is np.nanargmin else 8
ret = f(ar, keepdims=True)
assert ret.ndim == ar.ndim
assert ret == reference

@pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
def test_out(self, dtype):
ar = np.arange(9).astype(dtype)
ar[:5] = np.nan

for f in self.nanfuncs:
out = np.zeros((), dtype=np.intp)
reference = 5 if f is np.nanargmin else 8
ret = f(ar, out=out)
assert ret is out
assert ret == reference



_TEST_ARRAYS = {
"0d": np.array(5),
Expand Down Expand Up @@ -504,6 +568,30 @@ def test_empty(self):
res = f(mat, axis=None)
assert_equal(res, tgt)

@pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
def test_initial(self, dtype):
ar = np.arange(9).astype(dtype)
ar[:5] = np.nan

for f in self.nanfuncs:
reference = 28 if f is np.nansum else 3360
ret = f(ar, initial=2)
assert ret.dtype == dtype
assert ret == reference

@pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
def test_where(self, dtype):
ar = np.arange(9).reshape(3, 3).astype(dtype)
ar[0, :] = np.nan
where = np.ones_like(ar, dtype=np.bool_)
where[:, 0] = False

for f in self.nanfuncs:
reference = 26 if f is np.nansum else 2240
ret = f(ar, where=where, initial=2)
assert ret.dtype == dtype
assert ret == reference


class TestNanFunctions_CumSumProd(SharedNanFunctionsTestsMixin):

Expand Down Expand Up @@ -659,6 +747,21 @@ def test_empty(self):
assert_equal(f(mat, axis=axis), np.zeros([]))
assert_(len(w) == 0)

@pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
def test_where(self, dtype):
ar = np.arange(9).reshape(3, 3).astype(dtype)
ar[0, :] = np.nan
where = np.ones_like(ar, dtype=np.bool_)
where[:, 0] = False

for f, f_std in zip(self.nanfuncs, self.stdfuncs):
reference = f_std(ar[where][2:])
dtype_reference = dtype if f is np.nanmean else ar.real.dtype

ret = f(ar, where=where)
assert ret.dtype == dtype_reference
np.testing.assert_allclose(ret, reference)


_TIME_UNITS = (
"Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as"
Expand Down

0 comments on commit 986e22a

Please sign in to comment.