Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Adding keepdims to np.argmin,np.argmax #19211

Merged
merged 38 commits into from Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
0e5817e
keepdims added to np.argmin,np.argmax
czgdp1807 Jun 10, 2021
5d66427
Added release notes entry
czgdp1807 Jun 10, 2021
4d8f3af
tested for axis=None,keepdims=True
czgdp1807 Jun 10, 2021
aa2adc9
Apply suggestions from code review
czgdp1807 Jun 10, 2021
f7d4df4
updated interface
czgdp1807 Jun 10, 2021
1e725c7
updated interface
czgdp1807 Jun 10, 2021
4d2bfab
API changed, implementation to be done
czgdp1807 Jun 10, 2021
09433d6
Added reshape approach to C implementation
czgdp1807 Jun 11, 2021
8ea07b4
buggy implementation without reshape
czgdp1807 Jun 11, 2021
aa73907
TestArgMax, TestArgMin fixed, comments added
czgdp1807 Jun 12, 2021
22fbb06
Fixed for matrix
czgdp1807 Jun 12, 2021
d04946f
removed unrequired changes
czgdp1807 Jun 12, 2021
512b359
fixed CI failure
czgdp1807 Jun 12, 2021
da3df5b
fixed linting issue
czgdp1807 Jun 12, 2021
828df3b
PyArray_ArgMaxKeepdims now only modifies shape and strides
czgdp1807 Jun 18, 2021
91e7530
Comments added to PyArray_ArgMaxKeepdims
czgdp1807 Jun 18, 2021
a1c0faa
Updated implementation of PyArray_ArgMinKeepdims to match with PyArra…
czgdp1807 Jun 18, 2021
fa5839b
Testing complete for PyArray_ArgMinKeepdims and PyArray_ArgMaxKeepdims
czgdp1807 Jun 18, 2021
2cd3ff1
PyArray_ArgMinWithKeepdims both keepdims=True and keepdims=False
czgdp1807 Jun 19, 2021
124abe3
matched implementation of PyArray_ArgMaxKeepdims and PyArray_ArgMinKe…
czgdp1807 Jun 19, 2021
6bd9f4c
simplified implementation
czgdp1807 Jun 19, 2021
55a85d3
Added missing comment
czgdp1807 Jun 19, 2021
568251e
removed unwanted header
czgdp1807 Jun 19, 2021
9260d40
addressed all the reviews
czgdp1807 Jun 21, 2021
06d9610
Removing unwanted changes
czgdp1807 Jun 21, 2021
2112709
fixed docs
czgdp1807 Jun 21, 2021
e0dd74e
Added new lines
czgdp1807 Jun 21, 2021
11d7e33
restored annotations
czgdp1807 Jun 22, 2021
eec3e2f
Merge branch 'keepdims' of https://github.com/czgdp1807/numpy into ke…
czgdp1807 Jun 22, 2021
4c9f113
parametrized test
czgdp1807 Jun 24, 2021
db9c704
Apply suggestions from code review
czgdp1807 Jun 30, 2021
11cd597
keyword handling now done in np.argmin/np.argmax
czgdp1807 Jun 30, 2021
bb906bf
corrected indendation
czgdp1807 Jun 30, 2021
3b72f59
used with pytest.riases(ValueError)
czgdp1807 Jun 30, 2021
66adc84
fixed release notes
czgdp1807 Jun 30, 2021
4be86dd
removed PyArray_ArgMaxWithKeepdims and PyArray_ArgMinWithKeepdims fro…
czgdp1807 Jul 2, 2021
8f9f108
Apply suggestions from code review
czgdp1807 Jul 2, 2021
c55aaea
Apply suggestions from code review
czgdp1807 Jul 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/release/upcoming_changes/19211.new_feature.rst
@@ -0,0 +1,9 @@
`keepdims` optional argument added to `numpy.argmin`, `numpy.argmax`
--------------------------------------------------------------------

`keepdims` argument has been added to `numpy.argmin`, `numpy.argmax`.
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved
By default, it is `False` and hence no behaviour change should be expected
in existing codes using `numpy.argmin`, `numpy.argmax`. If set to `True`,
the axes which are reduced are left in the result as dimensions with size one.
In simple words, the resulting array will be having same dimensions
as the input array.
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion doc/source/reference/c-api/array.rst
Expand Up @@ -2014,7 +2014,8 @@ Calculation


.. c:function:: PyObject* PyArray_ArgMax( \
PyArrayObject* self, int axis, PyArrayObject* out)
PyArrayObject* self, int axis, PyArrayObject* out, \
int keepdims)

Equivalent to :meth:`ndarray.argmax<numpy.ndarray.argmax>` (*self*, *axis*). Return the index of
the largest element of *self* along *axis*.
Expand Down
2 changes: 1 addition & 1 deletion numpy/__init__.cython-30.pxd
Expand Up @@ -643,7 +643,7 @@ cdef extern from "numpy/arrayobject.h":
int PyArray_Sort (ndarray, int, NPY_SORTKIND)
object PyArray_ArgSort (ndarray, int, NPY_SORTKIND)
object PyArray_SearchSorted (ndarray, object, NPY_SEARCHSIDE, PyObject *)
object PyArray_ArgMax (ndarray, int, ndarray)
object PyArray_ArgMax (ndarray, int, ndarray, int)
object PyArray_ArgMin (ndarray, int, ndarray)
object PyArray_Reshape (ndarray, object)
object PyArray_Newshape (ndarray, PyArray_Dims *, NPY_ORDER)
Expand Down
2 changes: 1 addition & 1 deletion numpy/__init__.pxd
Expand Up @@ -601,7 +601,7 @@ cdef extern from "numpy/arrayobject.h":
int PyArray_Sort (ndarray, int, NPY_SORTKIND)
object PyArray_ArgSort (ndarray, int, NPY_SORTKIND)
object PyArray_SearchSorted (ndarray, object, NPY_SEARCHSIDE, PyObject *)
object PyArray_ArgMax (ndarray, int, ndarray)
object PyArray_ArgMax (ndarray, int, ndarray, int)
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved
object PyArray_ArgMin (ndarray, int, ndarray)
object PyArray_Reshape (ndarray, object)
object PyArray_Newshape (ndarray, PyArray_Dims *, NPY_ORDER)
Expand Down
8 changes: 7 additions & 1 deletion numpy/__init__.pyi
Expand Up @@ -1299,37 +1299,43 @@ class _ArrayOrScalarCommon:
self,
axis: None = ...,
out: None = ...,
keepdims: L[False] = ...,
) -> intp: ...
@overload
def argmax(
self,
axis: _ShapeLike = ...,
out: None = ...,
keepdims: bool = ...,
) -> Any: ...
@overload
def argmax(
self,
axis: Optional[_ShapeLike] = ...,
out: _NdArraySubClass = ...,
keepdims: bool = ...,
) -> _NdArraySubClass: ...

@overload
def argmin(
self,
axis: None = ...,
out: None = ...,
keepdims: L[False] = ...,
) -> intp: ...
@overload
def argmin(
self,
axis: _ShapeLike = ...,
out: None = ...,
out: None = ...,
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved
keepdims: bool = ...,
) -> Any: ...
@overload
def argmin(
self,
axis: Optional[_ShapeLike] = ...,
out: _NdArraySubClass = ...,
keepdims: bool = ...,
) -> _NdArraySubClass: ...

def argsort(
Expand Down
55 changes: 45 additions & 10 deletions numpy/core/fromnumeric.py
Expand Up @@ -1114,12 +1114,12 @@ def argsort(a, axis=-1, kind=None, order=None):
return _wrapfunc(a, 'argsort', axis=axis, kind=kind, order=order)


def _argmax_dispatcher(a, axis=None, out=None):
return (a, out)
def _argmax_dispatcher(a, axis=None, out=None, keepdims=None):
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved
return (a, out, keepdims)
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved


@array_function_dispatch(_argmax_dispatcher)
def argmax(a, axis=None, out=None):
def argmax(a, axis=None, out=None, keepdims=False):
"""
Returns the indices of the maximum values along an axis.

Expand All @@ -1133,12 +1133,18 @@ def argmax(a, axis=None, out=None):
out : array, optional
If provided, the result will be inserted into this array. It should
be of the appropriate shape and dtype.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left
in the result as dimensions with size one. With this option,
the result will broadcast correctly against the array.
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
index_array : ndarray of ints
Array of indices into the array. It has the same shape as `a.shape`
with the dimension along `axis` removed.
with the dimension along `axis` removed. If `keepdims` is set to True,
then the size of `axis` will be 1 with the resulting array having same
shape as `a.shape`.

See Also
--------
Expand Down Expand Up @@ -1191,16 +1197,23 @@ def argmax(a, axis=None, out=None):
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([4, 3])

Setting `keepdims` to `True`,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmax(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
"""
return _wrapfunc(a, 'argmax', axis=axis, out=out)
return _wrapfunc(a, 'argmax', axis=axis, out=out,
keepdims=keepdims)


def _argmin_dispatcher(a, axis=None, out=None):
return (a, out)
def _argmin_dispatcher(a, axis=None, out=None, keepdims=None):
return (a, out, keepdims)


@array_function_dispatch(_argmin_dispatcher)
def argmin(a, axis=None, out=None):
def argmin(a, axis=None, out=None, keepdims=False):
"""
Returns the indices of the minimum values along an axis.

Expand All @@ -1214,12 +1227,18 @@ def argmin(a, axis=None, out=None):
out : array, optional
If provided, the result will be inserted into this array. It should
be of the appropriate shape and dtype.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left
in the result as dimensions with size one. With this option,
the result will broadcast correctly against the array.

Returns
-------
index_array : ndarray of ints
Array of indices into the array. It has the same shape as `a.shape`
with the dimension along `axis` removed.
with the dimension along `axis` removed. If `keepdims` is set to True,
then the size of `axis` will be 1 with the resulting array having same
shape as `a.shape`.

See Also
--------
Expand Down Expand Up @@ -1272,8 +1291,24 @@ def argmin(a, axis=None, out=None):
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([2, 0])

Setting `keepdims` to `True`,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmin(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
"""
return _wrapfunc(a, 'argmin', axis=axis, out=out)
res = _wrapfunc(a, 'argmin', axis=axis, out=out)

if keepdims:
if axis is None:
new_shape = (1,)*a.ndim
else:
new_shape = list(a.shape)
new_shape[axis] = 1
return res.reshape(new_shape)

return res


def _searchsorted_dispatcher(a, v, side=None, sorter=None):
Expand Down
4 changes: 4 additions & 0 deletions numpy/core/fromnumeric.pyi
Expand Up @@ -130,25 +130,29 @@ def argmax(
a: ArrayLike,
axis: None = ...,
out: Optional[ndarray] = ...,
keepdims: Literal[False] = ...,
) -> intp: ...
@overload
def argmax(
a: ArrayLike,
axis: Optional[int] = ...,
out: Optional[ndarray] = ...,
keepdims: bool = ...,
) -> Any: ...

@overload
def argmin(
a: ArrayLike,
axis: None = ...,
out: Optional[ndarray] = ...,
keepdims: Literal[False] = ...,
) -> intp: ...
@overload
def argmin(
a: ArrayLike,
axis: Optional[int] = ...,
out: Optional[ndarray] = ...,
keepdims: bool = ...,
) -> Any: ...

@overload
Expand Down
3 changes: 2 additions & 1 deletion numpy/core/src/multiarray/calculation.c
Expand Up @@ -38,7 +38,8 @@ power_of_ten(int n)
* ArgMax
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out,
int keepdims)
{
PyArrayObject *ap = NULL, *rp = NULL;
PyArray_ArgFunc* arg_func;
Expand Down
3 changes: 2 additions & 1 deletion numpy/core/src/multiarray/calculation.h
Expand Up @@ -2,7 +2,8 @@
#define _NPY_CALCULATION_H_

NPY_NO_EXPORT PyObject*
PyArray_ArgMax(PyArrayObject* self, int axis, PyArrayObject *out);
PyArray_ArgMax(PyArrayObject* self, int axis, PyArrayObject *out,
int keepdims);

NPY_NO_EXPORT PyObject*
PyArray_ArgMin(PyArrayObject* self, int axis, PyArrayObject *out);
Expand Down
4 changes: 3 additions & 1 deletion numpy/core/src/multiarray/methods.c
Expand Up @@ -284,16 +284,18 @@ array_argmax(PyArrayObject *self,
{
int axis = NPY_MAXDIMS;
PyArrayObject *out = NULL;
int keepdims = -1;
NPY_PREPARE_ARGPARSER;

if (npy_parse_arguments("argmax", args, len_args, kwnames,
"|axis", &PyArray_AxisConverter, &axis,
"|out", &PyArray_OutputConverter, &out,
"|keepdims", &PyArray_BoolConverter, &keepdims,
NULL, NULL, NULL) < 0) {
return NULL;
}

PyObject *ret = PyArray_ArgMax(self, axis, out);
PyObject *ret = PyArray_ArgMax(self, axis, out, keepdims);

/* this matches the unpacking behavior of ufuncs */
if (out == NULL) {
Expand Down
30 changes: 30 additions & 0 deletions numpy/core/tests/test_multiarray.py
Expand Up @@ -4333,7 +4333,22 @@ def test_object_argmax_with_NULLs(self):
assert_equal(a.argmax(), 3)
a[1] = 30
assert_equal(a.argmax(), 1)

def test_np_argmax_keepdims(self):

sizes = [(3,), (3, 2), (2, 3),
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved
(3, 3), (2, 3, 4), (4, 3, 2)]
for size in sizes:
arr = np.random.normal(size=size)
for axis in range(len(size)):
res = np.argmax(arr, axis=axis, keepdims=True)
assert_(res.ndim == arr.ndim)
assert_(res.shape[axis] == 1)

# Testing for axis=None, keepdims=True
res = np.argmin(arr, axis=None, keepdims=True)
assert_(res.ndim == arr.ndim)
assert_(res.shape == (1,)*arr.ndim)

class TestArgmin:
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -4489,7 +4504,22 @@ def test_object_argmin_with_NULLs(self):
assert_equal(a.argmin(), 3)
a[1] = 10
assert_equal(a.argmin(), 1)

def test_np_argmin_keepdims(self):

sizes = [(3,), (3, 2), (2, 3),
(3, 3), (2, 3, 4), (4, 3, 2)]
for size in sizes:
arr = np.random.normal(size=size)
for axis in range(len(size)):
res = np.argmin(arr, axis=axis, keepdims=True)
assert_(res.ndim == arr.ndim)
assert_(res.shape[axis] == 1)

# Testing for axis=None, keepdims=True
res = np.argmin(arr, axis=None, keepdims=True)
assert_(res.ndim == arr.ndim)
assert_(res.shape == (1,)*arr.ndim)

class TestMinMax:
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
8 changes: 4 additions & 4 deletions numpy/ma/core.py
Expand Up @@ -5491,7 +5491,7 @@ def argsort(self, axis=np._NoValue, kind=None, order=None,
filled = self.filled(fill_value)
return filled.argsort(axis=axis, kind=kind, order=order)

def argmin(self, axis=None, fill_value=None, out=None):
def argmin(self, axis=None, fill_value=None, out=None, keepdims=False):
"""
Return array of indices to the minimum values along the given axis.

Expand Down Expand Up @@ -5534,9 +5534,9 @@ def argmin(self, axis=None, fill_value=None, out=None):
if fill_value is None:
fill_value = minimum_fill_value(self)
d = self.filled(fill_value).view(ndarray)
return d.argmin(axis, out=out)
return d.argmin(axis, out=out, keepdims=keepdims)

def argmax(self, axis=None, fill_value=None, out=None):
def argmax(self, axis=None, fill_value=None, out=None, keepdims=False):
"""
Returns array of indices of the maximum values along the given axis.
Masked values are treated as if they had the value fill_value.
Expand Down Expand Up @@ -5571,7 +5571,7 @@ def argmax(self, axis=None, fill_value=None, out=None):
if fill_value is None:
fill_value = maximum_fill_value(self._data)
d = self.filled(fill_value).view(ndarray)
return d.argmax(axis, out=out)
return d.argmax(axis, out=out, keepdims=keepdims)

def sort(self, axis=-1, kind=None, order=None,
endwith=True, fill_value=None):
Expand Down
4 changes: 2 additions & 2 deletions numpy/ma/core.pyi
Expand Up @@ -270,8 +270,8 @@ class MaskedArray(ndarray[_ShapeType, _DType_co]):
def std(self, axis=..., dtype=..., out=..., ddof=..., keepdims=...): ...
def round(self, decimals=..., out=...): ...
def argsort(self, axis=..., kind=..., order=..., endwith=..., fill_value=...): ...
def argmin(self, axis=..., fill_value=..., out=...): ...
def argmax(self, axis=..., fill_value=..., out=...): ...
def argmin(self, axis=..., fill_value=..., out=..., keepdims=...): ...
def argmax(self, axis=..., fill_value=..., out=..., keepdims=...): ...
def sort(self, axis=..., kind=..., order=..., endwith=..., fill_value=...): ...
def min(self, axis=..., out=..., fill_value=..., keepdims=...): ...
# NOTE: deprecated
Expand Down