Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: factored out _PyArray_ArgMinMaxCommon #19440

Merged
merged 4 commits into from
Jul 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
176 changes: 29 additions & 147 deletions numpy/core/src/multiarray/calculation.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ power_of_ten(int n)
}

NPY_NO_EXPORT PyObject *
_PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims)
_PyArray_ArgMinMaxCommon(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims,
npy_bool is_argmax)
mattip marked this conversation as resolved.
Show resolved Hide resolved
{
PyArrayObject *ap = NULL, *rp = NULL;
PyArray_ArgFunc* arg_func;
char *ip;
PyArray_ArgFunc* arg_func = NULL;
char *ip, *func_name;
npy_intp *rptr;
npy_intp i, n, m;
int elsize;
Expand Down Expand Up @@ -115,7 +116,14 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
}
}

arg_func = PyArray_DESCR(ap)->f->argmax;
if (is_argmax) {
func_name = "argmax";
arg_func = PyArray_DESCR(ap)->f->argmax;
}
else {
func_name = "argmin";
arg_func = PyArray_DESCR(ap)->f->argmin;
}
if (arg_func == NULL) {
PyErr_SetString(PyExc_TypeError,
"data type not ordered");
Expand All @@ -124,8 +132,9 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
elsize = PyArray_DESCR(ap)->elsize;
m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1];
if (m == 0) {
PyErr_SetString(PyExc_ValueError,
"attempt to get argmax of an empty sequence");
PyErr_Format(PyExc_ValueError,
"attempt to get %s of an empty sequence",
func_name);
goto fail;
}

Expand All @@ -142,8 +151,9 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
if ((PyArray_NDIM(out) != out_ndim) ||
!PyArray_CompareLists(PyArray_DIMS(out), out_shape,
out_ndim)) {
PyErr_SetString(PyExc_ValueError,
"output array does not match result of np.argmax.");
PyErr_Format(PyExc_ValueError,
"output array does not match result of np.%s.",
func_name);
goto fail;
}
rp = (PyArrayObject *)PyArray_FromArray(out,
Expand Down Expand Up @@ -179,155 +189,27 @@ _PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
return NULL;
}

NPY_NO_EXPORT PyObject*
_PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims)
{
return _PyArray_ArgMinMaxCommon(op, axis, out, keepdims, 1);
}

/*NUMPY_API
* ArgMax
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
{
return _PyArray_ArgMaxWithKeepdims(op, axis, out, 0);
return _PyArray_ArgMinMaxCommon(op, axis, out, 0, 1);
}

NPY_NO_EXPORT PyObject *
_PyArray_ArgMinWithKeepdims(PyArrayObject *op,
int axis, PyArrayObject *out, int keepdims)
{
PyArrayObject *ap = NULL, *rp = NULL;
PyArray_ArgFunc* arg_func;
char *ip;
npy_intp *rptr;
npy_intp i, n, m;
int elsize;
// Keep a copy because axis changes via call to PyArray_CheckAxis
int axis_copy = axis;
npy_intp _shape_buf[NPY_MAXDIMS];
npy_intp *out_shape;
// Keep the number of dimensions and shape of
// original array. Helps when `keepdims` is True.
npy_intp* original_op_shape = PyArray_DIMS(op);
int out_ndim = PyArray_NDIM(op);
NPY_BEGIN_THREADS_DEF;

if ((ap = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
return NULL;
}
/*
* We need to permute the array so that axis is placed at the end.
* And all other dimensions are shifted left.
*/
if (axis != PyArray_NDIM(ap)-1) {
PyArray_Dims newaxes;
npy_intp dims[NPY_MAXDIMS];
int i;

newaxes.ptr = dims;
newaxes.len = PyArray_NDIM(ap);
for (i = 0; i < axis; i++) {
dims[i] = i;
}
for (i = axis; i < PyArray_NDIM(ap) - 1; i++) {
dims[i] = i + 1;
}
dims[PyArray_NDIM(ap) - 1] = axis;
op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
Py_DECREF(ap);
if (op == NULL) {
return NULL;
}
}
else {
op = ap;
}

/* Will get native-byte order contiguous copy. */
ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op,
PyArray_DESCR(op)->type_num, 1, 0);
Py_DECREF(op);
if (ap == NULL) {
return NULL;
}

// Decides the shape of the output array.
if (!keepdims) {
out_ndim = PyArray_NDIM(ap) - 1;
out_shape = PyArray_DIMS(ap);
} else {
out_shape = _shape_buf;
if (axis_copy == NPY_MAXDIMS) {
for (int i = 0; i < out_ndim; i++) {
out_shape[i] = 1;
}
} else {
/*
* While `ap` may be transposed, we can ignore this for `out` because the
* transpose only reorders the size 1 `axis` (not changing memory layout).
*/
memcpy(out_shape, original_op_shape, out_ndim * sizeof(npy_intp));
out_shape[axis] = 1;
}
}

arg_func = PyArray_DESCR(ap)->f->argmin;
if (arg_func == NULL) {
PyErr_SetString(PyExc_TypeError,
"data type not ordered");
goto fail;
}
elsize = PyArray_DESCR(ap)->elsize;
m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1];
if (m == 0) {
PyErr_SetString(PyExc_ValueError,
"attempt to get argmin of an empty sequence");
mattip marked this conversation as resolved.
Show resolved Hide resolved
goto fail;
}

if (!out) {
rp = (PyArrayObject *)PyArray_NewFromDescr(
Py_TYPE(ap), PyArray_DescrFromType(NPY_INTP),
out_ndim, out_shape, NULL, NULL,
0, (PyObject *)ap);
if (rp == NULL) {
goto fail;
}
}
else {
if ((PyArray_NDIM(out) != out_ndim) ||
!PyArray_CompareLists(PyArray_DIMS(out), out_shape, out_ndim)) {
PyErr_SetString(PyExc_ValueError,
"output array does not match result of np.argmin.");
goto fail;
}
rp = (PyArrayObject *)PyArray_FromArray(out,
PyArray_DescrFromType(NPY_INTP),
NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY);
if (rp == NULL) {
goto fail;
}
}

NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap));
n = PyArray_SIZE(ap)/m;
rptr = (npy_intp *)PyArray_DATA(rp);
for (ip = PyArray_DATA(ap), i = 0; i < n; i++, ip += elsize*m) {
arg_func(ip, m, rptr, ap);
rptr += 1;
}
NPY_END_THREADS_DESCR(PyArray_DESCR(ap));

Py_DECREF(ap);
/* Trigger the UPDATEIFCOPY/WRITEBACKIFCOPY if necessary */
if (out != NULL && out != rp) {
PyArray_ResolveWritebackIfCopy(rp);
Py_DECREF(rp);
rp = out;
Py_INCREF(rp);
}
return (PyObject *)rp;

fail:
Py_DECREF(ap);
Py_XDECREF(rp);
return NULL;
return _PyArray_ArgMinMaxCommon(op, axis, out, keepdims, 0);
}

/*NUMPY_API
Expand All @@ -336,7 +218,7 @@ _PyArray_ArgMinWithKeepdims(PyArrayObject *op,
NPY_NO_EXPORT PyObject *
PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
{
return _PyArray_ArgMinWithKeepdims(op, axis, out, 0);
return _PyArray_ArgMinMaxCommon(op, axis, out, 0, 0);
}

/*NUMPY_API
Expand Down