Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Adding keepdims to np.argmin,np.argmax #19211

Merged
merged 38 commits into from Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
0e5817e
keepdims added to np.argmin,np.argmax
czgdp1807 Jun 10, 2021
5d66427
Added release notes entry
czgdp1807 Jun 10, 2021
4d8f3af
tested for axis=None,keepdims=True
czgdp1807 Jun 10, 2021
aa2adc9
Apply suggestions from code review
czgdp1807 Jun 10, 2021
f7d4df4
updated interface
czgdp1807 Jun 10, 2021
1e725c7
updated interface
czgdp1807 Jun 10, 2021
4d2bfab
API changed, implementation to be done
czgdp1807 Jun 10, 2021
09433d6
Added reshape approach to C implementation
czgdp1807 Jun 11, 2021
8ea07b4
buggy implementation without reshape
czgdp1807 Jun 11, 2021
aa73907
TestArgMax, TestArgMin fixed, comments added
czgdp1807 Jun 12, 2021
22fbb06
Fixed for matrix
czgdp1807 Jun 12, 2021
d04946f
removed unrequired changes
czgdp1807 Jun 12, 2021
512b359
fixed CI failure
czgdp1807 Jun 12, 2021
da3df5b
fixed linting issue
czgdp1807 Jun 12, 2021
828df3b
PyArray_ArgMaxKeepdims now only modifies shape and strides
czgdp1807 Jun 18, 2021
91e7530
Comments added to PyArray_ArgMaxKeepdims
czgdp1807 Jun 18, 2021
a1c0faa
Updated implementation of PyArray_ArgMinKeepdims to match with PyArra…
czgdp1807 Jun 18, 2021
fa5839b
Testing complete for PyArray_ArgMinKeepdims and PyArray_ArgMaxKeepdims
czgdp1807 Jun 18, 2021
2cd3ff1
PyArray_ArgMinWithKeepdims both keepdims=True and keepdims=False
czgdp1807 Jun 19, 2021
124abe3
matched implementation of PyArray_ArgMaxKeepdims and PyArray_ArgMinKe…
czgdp1807 Jun 19, 2021
6bd9f4c
simplified implementation
czgdp1807 Jun 19, 2021
55a85d3
Added missing comment
czgdp1807 Jun 19, 2021
568251e
removed unwanted header
czgdp1807 Jun 19, 2021
9260d40
addressed all the reviews
czgdp1807 Jun 21, 2021
06d9610
Removing unwanted changes
czgdp1807 Jun 21, 2021
2112709
fixed docs
czgdp1807 Jun 21, 2021
e0dd74e
Added new lines
czgdp1807 Jun 21, 2021
11d7e33
restored annotations
czgdp1807 Jun 22, 2021
eec3e2f
Merge branch 'keepdims' of https://github.com/czgdp1807/numpy into ke…
czgdp1807 Jun 22, 2021
4c9f113
parametrized test
czgdp1807 Jun 24, 2021
db9c704
Apply suggestions from code review
czgdp1807 Jun 30, 2021
11cd597
keyword handling now done in np.argmin/np.argmax
czgdp1807 Jun 30, 2021
bb906bf
corrected indendation
czgdp1807 Jun 30, 2021
3b72f59
used with pytest.riases(ValueError)
czgdp1807 Jun 30, 2021
66adc84
fixed release notes
czgdp1807 Jun 30, 2021
4be86dd
removed PyArray_ArgMaxWithKeepdims and PyArray_ArgMinWithKeepdims fro…
czgdp1807 Jul 2, 2021
8f9f108
Apply suggestions from code review
czgdp1807 Jul 2, 2021
c55aaea
Apply suggestions from code review
czgdp1807 Jul 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/release/upcoming_changes/19211.new_feature.rst
@@ -0,0 +1,7 @@
``keepdims`` optional argument added to `numpy.argmin`, `numpy.argmax`
----------------------------------------------------------------------

``keepdims`` argument is added to `numpy.argmin`, `numpy.argmax`.
If set to ``True``, the axes which are reduced are left in the result as dimensions with size one.
The resulting array has the same number of dimensions and will broadcast with the
input array.
14 changes: 13 additions & 1 deletion numpy/__init__.pyi
Expand Up @@ -1299,37 +1299,49 @@ class _ArrayOrScalarCommon:
self,
axis: None = ...,
out: None = ...,
*,
keepdims: L[False] = ...,
) -> intp: ...
@overload
def argmax(
self,
axis: _ShapeLike = ...,
out: None = ...,
*,
keepdims: bool = ...,
) -> Any: ...
@overload
def argmax(
self,
axis: Optional[_ShapeLike] = ...,
out: _NdArraySubClass = ...,
*,
keepdims: bool = ...,
) -> _NdArraySubClass: ...

@overload
def argmin(
self,
axis: None = ...,
out: None = ...,
*,
keepdims: L[False] = ...,
) -> intp: ...
@overload
def argmin(
self,
axis: _ShapeLike = ...,
out: None = ...,
out: None = ...,
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved
*,
keepdims: bool = ...,
) -> Any: ...
@overload
def argmin(
self,
axis: Optional[_ShapeLike] = ...,
out: _NdArraySubClass = ...,
*,
keepdims: bool = ...,
) -> _NdArraySubClass: ...

def argsort(
Expand Down
42 changes: 34 additions & 8 deletions numpy/core/fromnumeric.py
Expand Up @@ -1114,12 +1114,12 @@ def argsort(a, axis=-1, kind=None, order=None):
return _wrapfunc(a, 'argsort', axis=axis, kind=kind, order=order)


def _argmax_dispatcher(a, axis=None, out=None):
def _argmax_dispatcher(a, axis=None, out=None, *, keepdims=np._NoValue):
return (a, out)


@array_function_dispatch(_argmax_dispatcher)
def argmax(a, axis=None, out=None):
def argmax(a, axis=None, out=None, *, keepdims=np._NoValue):
"""
Returns the indices of the maximum values along an axis.

Expand All @@ -1133,12 +1133,18 @@ def argmax(a, axis=None, out=None):
out : array, optional
If provided, the result will be inserted into this array. It should
be of the appropriate shape and dtype.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left
in the result as dimensions with size one. With this option,
the result will broadcast correctly against the array.

Returns
-------
index_array : ndarray of ints
Array of indices into the array. It has the same shape as `a.shape`
with the dimension along `axis` removed.
with the dimension along `axis` removed. If `keepdims` is set to True,
then the size of `axis` will be 1 with the resulting array having same
shape as `a.shape`.

See Also
--------
Expand Down Expand Up @@ -1191,16 +1197,23 @@ def argmax(a, axis=None, out=None):
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([4, 3])

Setting `keepdims` to `True`,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmax(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
"""
return _wrapfunc(a, 'argmax', axis=axis, out=out)
kwds = {'keepdims': keepdims} if keepdims is not np._NoValue else {}
return _wrapfunc(a, 'argmax', axis=axis, out=out, **kwds)


def _argmin_dispatcher(a, axis=None, out=None):
def _argmin_dispatcher(a, axis=None, out=None, *, keepdims=np._NoValue):
return (a, out)


@array_function_dispatch(_argmin_dispatcher)
def argmin(a, axis=None, out=None):
def argmin(a, axis=None, out=None, *, keepdims=np._NoValue):
"""
Returns the indices of the minimum values along an axis.

Expand All @@ -1214,12 +1227,18 @@ def argmin(a, axis=None, out=None):
out : array, optional
If provided, the result will be inserted into this array. It should
be of the appropriate shape and dtype.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left
in the result as dimensions with size one. With this option,
the result will broadcast correctly against the array.

Returns
-------
index_array : ndarray of ints
Array of indices into the array. It has the same shape as `a.shape`
with the dimension along `axis` removed.
with the dimension along `axis` removed. If `keepdims` is set to True,
then the size of `axis` will be 1 with the resulting array having same
shape as `a.shape`.

See Also
--------
Expand Down Expand Up @@ -1272,8 +1291,15 @@ def argmin(a, axis=None, out=None):
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([2, 0])

Setting `keepdims` to `True`,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmin(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
"""
return _wrapfunc(a, 'argmin', axis=axis, out=out)
kwds = {'keepdims': keepdims} if keepdims is not np._NoValue else {}
return _wrapfunc(a, 'argmin', axis=axis, out=out, **kwds)


def _searchsorted_dispatcher(a, v, side=None, sorter=None):
Expand Down
8 changes: 8 additions & 0 deletions numpy/core/fromnumeric.pyi
Expand Up @@ -130,25 +130,33 @@ def argmax(
a: ArrayLike,
axis: None = ...,
out: Optional[ndarray] = ...,
*,
keepdims: Literal[False] = ...,
) -> intp: ...
@overload
def argmax(
a: ArrayLike,
axis: Optional[int] = ...,
out: Optional[ndarray] = ...,
*,
keepdims: bool = ...,
) -> Any: ...

@overload
def argmin(
a: ArrayLike,
axis: None = ...,
out: Optional[ndarray] = ...,
*,
keepdims: Literal[False] = ...,
) -> intp: ...
@overload
def argmin(
a: ArrayLike,
axis: Optional[int] = ...,
out: Optional[ndarray] = ...,
*,
keepdims: bool = ...,
) -> Any: ...

@overload
Expand Down
103 changes: 88 additions & 15 deletions numpy/core/src/multiarray/calculation.c
Expand Up @@ -34,18 +34,24 @@ power_of_ten(int n)
return ret;
}

/*NUMPY_API
* ArgMax
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
_PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims)
{
PyArrayObject *ap = NULL, *rp = NULL;
PyArray_ArgFunc* arg_func;
char *ip;
npy_intp *rptr;
npy_intp i, n, m;
int elsize;
// Keep a copy because axis changes via call to PyArray_CheckAxis
int axis_copy = axis;
npy_intp _shape_buf[NPY_MAXDIMS];
npy_intp *out_shape;
// Keep the number of dimensions and shape of
// original array. Helps when `keepdims` is True.
npy_intp* original_op_shape = PyArray_DIMS(op);
int out_ndim = PyArray_NDIM(op);
NPY_BEGIN_THREADS_DEF;

if ((ap = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
Expand Down Expand Up @@ -86,6 +92,29 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
if (ap == NULL) {
return NULL;
}

// Decides the shape of the output array.
if (!keepdims) {
out_ndim = PyArray_NDIM(ap) - 1;
out_shape = PyArray_DIMS(ap);
}
else {
out_shape = _shape_buf;
if (axis_copy == NPY_MAXDIMS) {
for (int i = 0; i < out_ndim; i++) {
out_shape[i] = 1;
}
}
else {
/*
* While `ap` may be transposed, we can ignore this for `out` because the
* transpose only reorders the size 1 `axis` (not changing memory layout).
*/
memcpy(out_shape, original_op_shape, out_ndim * sizeof(npy_intp));
out_shape[axis] = 1;
}
}

arg_func = PyArray_DESCR(ap)->f->argmax;
if (arg_func == NULL) {
PyErr_SetString(PyExc_TypeError,
Expand All @@ -103,16 +132,16 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
if (!out) {
rp = (PyArrayObject *)PyArray_NewFromDescr(
Py_TYPE(ap), PyArray_DescrFromType(NPY_INTP),
PyArray_NDIM(ap) - 1, PyArray_DIMS(ap), NULL, NULL,
out_ndim, out_shape, NULL, NULL,
0, (PyObject *)ap);
if (rp == NULL) {
goto fail;
}
}
else {
if ((PyArray_NDIM(out) != PyArray_NDIM(ap) - 1) ||
!PyArray_CompareLists(PyArray_DIMS(out), PyArray_DIMS(ap),
PyArray_NDIM(out))) {
czgdp1807 marked this conversation as resolved.
Show resolved Hide resolved
if ((PyArray_NDIM(out) != out_ndim) ||
!PyArray_CompareLists(PyArray_DIMS(out), out_shape,
out_ndim)) {
PyErr_SetString(PyExc_ValueError,
"output array does not match result of np.argmax.");
goto fail;
Expand All @@ -135,7 +164,7 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
NPY_END_THREADS_DESCR(PyArray_DESCR(ap));

Py_DECREF(ap);
/* Trigger the UPDATEIFCOPY/WRTIEBACKIFCOPY if necessary */
/* Trigger the UPDATEIFCOPY/WRITEBACKIFCOPY if necessary */
if (out != NULL && out != rp) {
PyArray_ResolveWritebackIfCopy(rp);
Py_DECREF(rp);
Expand All @@ -151,17 +180,32 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
}

/*NUMPY_API
* ArgMin
* ArgMax
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
{
return _PyArray_ArgMaxWithKeepdims(op, axis, out, 0);
}

NPY_NO_EXPORT PyObject *
_PyArray_ArgMinWithKeepdims(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims)
{
PyArrayObject *ap = NULL, *rp = NULL;
PyArray_ArgFunc* arg_func;
char *ip;
npy_intp *rptr;
npy_intp i, n, m;
int elsize;
// Keep a copy because axis changes via call to PyArray_CheckAxis
int axis_copy = axis;
npy_intp _shape_buf[NPY_MAXDIMS];
npy_intp *out_shape;
// Keep the number of dimensions and shape of
// original array. Helps when `keepdims` is True.
npy_intp* original_op_shape = PyArray_DIMS(op);
int out_ndim = PyArray_NDIM(op);
NPY_BEGIN_THREADS_DEF;

if ((ap = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
Expand Down Expand Up @@ -202,6 +246,27 @@ PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
if (ap == NULL) {
return NULL;
}

// Decides the shape of the output array.
if (!keepdims) {
out_ndim = PyArray_NDIM(ap) - 1;
out_shape = PyArray_DIMS(ap);
} else {
out_shape = _shape_buf;
if (axis_copy == NPY_MAXDIMS) {
for (int i = 0; i < out_ndim; i++) {
out_shape[i] = 1;
}
} else {
/*
* While `ap` may be transposed, we can ignore this for `out` because the
* transpose only reorders the size 1 `axis` (not changing memory layout).
*/
memcpy(out_shape, original_op_shape, out_ndim * sizeof(npy_intp));
out_shape[axis] = 1;
}
}

arg_func = PyArray_DESCR(ap)->f->argmin;
if (arg_func == NULL) {
PyErr_SetString(PyExc_TypeError,
Expand All @@ -219,16 +284,15 @@ PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
if (!out) {
rp = (PyArrayObject *)PyArray_NewFromDescr(
Py_TYPE(ap), PyArray_DescrFromType(NPY_INTP),
PyArray_NDIM(ap) - 1, PyArray_DIMS(ap), NULL, NULL,
out_ndim, out_shape, NULL, NULL,
0, (PyObject *)ap);
if (rp == NULL) {
goto fail;
}
}
else {
if ((PyArray_NDIM(out) != PyArray_NDIM(ap) - 1) ||
!PyArray_CompareLists(PyArray_DIMS(out), PyArray_DIMS(ap),
PyArray_NDIM(out))) {
if ((PyArray_NDIM(out) != out_ndim) ||
!PyArray_CompareLists(PyArray_DIMS(out), out_shape, out_ndim)) {
PyErr_SetString(PyExc_ValueError,
"output array does not match result of np.argmin.");
goto fail;
Expand Down Expand Up @@ -266,6 +330,15 @@ PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
return NULL;
}

/*NUMPY_API
* ArgMin
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
{
return _PyArray_ArgMinWithKeepdims(op, axis, out, 0);
}

/*NUMPY_API
* Max
*/
Expand Down
6 changes: 6 additions & 0 deletions numpy/core/src/multiarray/calculation.h
Expand Up @@ -4,9 +4,15 @@
NPY_NO_EXPORT PyObject*
PyArray_ArgMax(PyArrayObject* self, int axis, PyArrayObject *out);

NPY_NO_EXPORT PyObject*
_PyArray_ArgMaxWithKeepdims(PyArrayObject* self, int axis, PyArrayObject *out, int keepdims);

NPY_NO_EXPORT PyObject*
PyArray_ArgMin(PyArrayObject* self, int axis, PyArrayObject *out);

NPY_NO_EXPORT PyObject*
_PyArray_ArgMinWithKeepdims(PyArrayObject* self, int axis, PyArrayObject *out, int keepdims);

NPY_NO_EXPORT PyObject*
PyArray_Max(PyArrayObject* self, int axis, PyArrayObject* out);

Expand Down