Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add equal_nan kwarg to np.unique #21646

Merged
merged 6 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/release/upcoming_changes/21623.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
New parameter ``equal_nan`` added to `np.unique`
------------------------------------------------

`np.unique` was changed in 1.21 to treat all ``NaN`` values as equal and return
a single ``NaN``. Setting ``equal_nan=False`` restores the pre-1.21 behavior, in
which every ``NaN`` is treated as a distinct value. The default is ``True``.
22 changes: 13 additions & 9 deletions numpy/lib/arraysetops.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ def _unpack_tuple(x):


def _unique_dispatcher(ar, return_index=None, return_inverse=None,
return_counts=None, axis=None):
return_counts=None, axis=None, *, equal_nan=None):
return (ar,)


@array_function_dispatch(_unique_dispatcher)
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis=None):
return_counts=False, axis=None, *, equal_nan=True):
"""
Find the unique elements of an array.

Expand All @@ -162,9 +162,6 @@ def unique(ar, return_index=False, return_inverse=False,
return_counts : bool, optional
If True, also return the number of times each unique item appears
in `ar`.

.. versionadded:: 1.9.0

axis : int or None, optional
The axis to operate on. If None, `ar` will be flattened. If an integer,
the subarrays indexed by the given axis will be flattened and treated
Expand All @@ -175,6 +172,11 @@ def unique(ar, return_index=False, return_inverse=False,

.. versionadded:: 1.13.0

equal_nan : bool, optional
If True, collapses multiple NaN values in the return array into one.

.. versionadded:: 1.24

Returns
-------
unique : ndarray
Expand Down Expand Up @@ -269,7 +271,8 @@ def unique(ar, return_index=False, return_inverse=False,
"""
ar = np.asanyarray(ar)
if axis is None:
ret = _unique1d(ar, return_index, return_inverse, return_counts)
ret = _unique1d(ar, return_index, return_inverse, return_counts,
equal_nan=equal_nan)
return _unpack_tuple(ret)

# axis was specified and not None
Expand Down Expand Up @@ -312,13 +315,13 @@ def reshape_uniq(uniq):
return uniq

output = _unique1d(consolidated, return_index,
return_inverse, return_counts)
return_inverse, return_counts, equal_nan=equal_nan)
output = (reshape_uniq(output[0]),) + output[1:]
return _unpack_tuple(output)


def _unique1d(ar, return_index=False, return_inverse=False,
return_counts=False):
return_counts=False, *, equal_nan=True):
"""
Find the unique elements of an array, ignoring shape.
"""
Expand All @@ -334,7 +337,8 @@ def _unique1d(ar, return_index=False, return_inverse=False,
aux = ar
mask = np.empty(aux.shape, dtype=np.bool_)
mask[:1] = True
if aux.shape[0] > 0 and aux.dtype.kind in "cfmM" and np.isnan(aux[-1]):
if (equal_nan and aux.shape[0] > 0 and aux.dtype.kind in "cfmM" and
np.isnan(aux[-1])):
if aux.dtype.kind == "c": # for complex all NaNs are considered equivalent
aux_firstnan = np.searchsorted(np.isnan(aux), True, side='left')
else:
Expand Down
8 changes: 8 additions & 0 deletions numpy/lib/tests/test_arraysetops.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,3 +765,11 @@ def _run_axis_tests(self, dtype):
assert_array_equal(uniq[:, inv], data)
msg = "Unique's return_counts=True failed with axis=1"
assert_array_equal(cnt, np.array([2, 1, 1]), msg)

def test_unique_nanequals(self):
# issue 20326
a = np.array([1, 1, np.nan, np.nan, np.nan])
unq = np.unique(a)
not_unq = np.unique(a, equal_nan=False)
assert_array_equal(unq, np.array([1, np.nan]))
assert_array_equal(not_unq, np.array([1, np.nan, np.nan, np.nan]))