Skip to content

Commit

Permalink
Merge pull request #21644 from seberg/unique-equal-nan
Browse files Browse the repository at this point in the history
MAINT: Fixup `unique`s `equal_nan` kwarg to match `np.array_equal`
  • Loading branch information
mattip committed Jun 1, 2022
2 parents 6cada27 + 911015e commit 07709f3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
6 changes: 3 additions & 3 deletions doc/release/upcoming_changes/21623.new_feature.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
New parameter ``equal_nans`` added to `np.unique`
-----------------------------------------------------------------------------------
New parameter ``equal_nan`` added to `np.unique`
------------------------------------------------

`np.unique` was changed in 1.21 to treat all ``NaN`` values as equal and return
a single ``NaN``. Setting ``equal_nans=False`` will restore pre-1.21 behavior
a single ``NaN``. Setting ``equal_nan=False`` will restore pre-1.21 behavior
to treat ``NaNs`` as unique. Defaults to ``True``.
22 changes: 11 additions & 11 deletions numpy/lib/arraysetops.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ def _unpack_tuple(x):


def _unique_dispatcher(ar, return_index=None, return_inverse=None,
return_counts=None, axis=None, *, equal_nans=None):
return_counts=None, axis=None, *, equal_nan=None):
return (ar,)


@array_function_dispatch(_unique_dispatcher)
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis=None, *, equal_nans=True):
return_counts=False, axis=None, *, equal_nan=True):
"""
Find the unique elements of an array.
Expand All @@ -162,11 +162,6 @@ def unique(ar, return_index=False, return_inverse=False,
return_counts : bool, optional
If True, also return the number of times each unique item appears
in `ar`.
equal_nans : bool, optional
If True, collapses multiple NaN values in return array into 1
.. versionchanged: 1.24
axis : int or None, optional
The axis to operate on. If None, `ar` will be flattened. If an integer,
the subarrays indexed by the given axis will be flattened and treated
Expand All @@ -177,6 +172,11 @@ def unique(ar, return_index=False, return_inverse=False,
.. versionadded:: 1.13.0
equal_nan : bool, optional
If True, collapses multiple NaN values in the return array into one.
.. versionadded:: 1.24
Returns
-------
unique : ndarray
Expand Down Expand Up @@ -272,7 +272,7 @@ def unique(ar, return_index=False, return_inverse=False,
ar = np.asanyarray(ar)
if axis is None:
ret = _unique1d(ar, return_index, return_inverse, return_counts,
equal_nans = equal_nans)
equal_nan=equal_nan)
return _unpack_tuple(ret)

# axis was specified and not None
Expand Down Expand Up @@ -315,13 +315,13 @@ def reshape_uniq(uniq):
return uniq

output = _unique1d(consolidated, return_index,
return_inverse, return_counts, equal_nans = equal_nans)
return_inverse, return_counts, equal_nan=equal_nan)
output = (reshape_uniq(output[0]),) + output[1:]
return _unpack_tuple(output)


def _unique1d(ar, return_index=False, return_inverse=False,
return_counts=False, *, equal_nans=True):
return_counts=False, *, equal_nan=True):
"""
Find the unique elements of an array, ignoring shape.
"""
Expand All @@ -337,7 +337,7 @@ def _unique1d(ar, return_index=False, return_inverse=False,
aux = ar
mask = np.empty(aux.shape, dtype=np.bool_)
mask[:1] = True
if (equal_nans and aux.shape[0] > 0 and aux.dtype.kind in "cfmM" and
if (equal_nan and aux.shape[0] > 0 and aux.dtype.kind in "cfmM" and
np.isnan(aux[-1])):
if aux.dtype.kind == "c": # for complex all NaNs are considered equivalent
aux_firstnan = np.searchsorted(np.isnan(aux), True, side='left')
Expand Down
2 changes: 1 addition & 1 deletion numpy/lib/tests/test_arraysetops.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,6 @@ def test_unique_nanequals(self):
# issue 20326
a = np.array([1, 1, np.nan, np.nan, np.nan])
unq = np.unique(a)
not_unq = np.unique(a, equal_nans = False)
not_unq = np.unique(a, equal_nan=False)
assert_array_equal(unq, np.array([1, np.nan]))
assert_array_equal(not_unq, np.array([1, np.nan, np.nan, np.nan]))

0 comments on commit 07709f3

Please sign in to comment.