Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add equals_nans kwarg to np.unique #21623

Merged
merged 5 commits into from Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/release/upcoming_changes/21623.new_feature.rst
@@ -0,0 +1,4 @@
New parameter ``equal_nans`` added to `np.unique`
-----------------------------------------------------------------------------------

`np.unique` was previously changed to treat NaN values as equal. Now this functionality is decided by setting the ``equal_nans`` kwarg to True or False. True is the default behavior.
rjeb marked this conversation as resolved.
Show resolved Hide resolved
18 changes: 11 additions & 7 deletions numpy/lib/arraysetops.py
Expand Up @@ -131,13 +131,13 @@ def _unpack_tuple(x):


def _unique_dispatcher(ar, return_index=None, return_inverse=None,
return_counts=None, axis=None):
return_counts=None, axis=None, *, equal_nans=None):
return (ar,)


@array_function_dispatch(_unique_dispatcher)
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis=None):
return_counts=False, axis=None, *, equal_nans=True):
"""
Find the unique elements of an array.

Expand All @@ -162,8 +162,10 @@ def unique(ar, return_index=False, return_inverse=False,
return_counts : bool, optional
If True, also return the number of times each unique item appears
in `ar`.
equal_nans : bool, optional
If True, collapses multiple NaN values in return array into 1

.. versionadded:: 1.9.0
.. versionchanged: NumPy 1.24
rjeb marked this conversation as resolved.
Show resolved Hide resolved

axis : int or None, optional
The axis to operate on. If None, `ar` will be flattened. If an integer,
Expand Down Expand Up @@ -269,7 +271,8 @@ def unique(ar, return_index=False, return_inverse=False,
"""
ar = np.asanyarray(ar)
if axis is None:
ret = _unique1d(ar, return_index, return_inverse, return_counts)
ret = _unique1d(ar, return_index, return_inverse, return_counts,
equal_nans = equal_nans)
return _unpack_tuple(ret)

# axis was specified and not None
Expand Down Expand Up @@ -312,13 +315,13 @@ def reshape_uniq(uniq):
return uniq

output = _unique1d(consolidated, return_index,
return_inverse, return_counts)
return_inverse, return_counts, equal_nans = equal_nans)
output = (reshape_uniq(output[0]),) + output[1:]
return _unpack_tuple(output)


def _unique1d(ar, return_index=False, return_inverse=False,
return_counts=False):
return_counts=False, *, equal_nans=True):
"""
Find the unique elements of an array, ignoring shape.
"""
Expand All @@ -334,7 +337,8 @@ def _unique1d(ar, return_index=False, return_inverse=False,
aux = ar
mask = np.empty(aux.shape, dtype=np.bool_)
mask[:1] = True
if aux.shape[0] > 0 and aux.dtype.kind in "cfmM" and np.isnan(aux[-1]):
if (equal_nans and aux.shape[0] > 0 and aux.dtype.kind in "cfmM" and
np.isnan(aux[-1])):
rjeb marked this conversation as resolved.
Show resolved Hide resolved
if aux.dtype.kind == "c": # for complex all NaNs are considered equivalent
aux_firstnan = np.searchsorted(np.isnan(aux), True, side='left')
else:
Expand Down
8 changes: 8 additions & 0 deletions numpy/lib/tests/test_arraysetops.py
Expand Up @@ -765,3 +765,11 @@ def _run_axis_tests(self, dtype):
assert_array_equal(uniq[:, inv], data)
msg = "Unique's return_counts=True failed with axis=1"
assert_array_equal(cnt, np.array([2, 1, 1]), msg)

def test_unique_nanequals(self):
# issue 20326
a = np.array([1, 1, np.nan, np.nan, np.nan])
unq = np.unique(a)
not_unq = np.unique(a, equal_nans = False)
assert_array_equal(unq, np.array([1, np.nan]))
assert_array_equal(not_unq, np.array([1, np.nan, np.nan, np.nan]))