Skip to content

Commit

Permalink
MAINT: Refactor and simplify the main ufunc iterator loop code
Browse files Browse the repository at this point in the history
Simple cleanups to the main ufunc loop code in preparation for
larger ones.
  • Loading branch information
seberg committed Jun 16, 2021
1 parent b5cc1f8 commit 6e144f0
Showing 1 changed file with 61 additions and 56 deletions.
117 changes: 61 additions & 56 deletions numpy/core/src/umath/ufunc_object.c
Original file line number Diff line number Diff line change
Expand Up @@ -1263,23 +1263,10 @@ iterator_loop(PyUFuncObject *ufunc,
void *innerloopdata,
npy_uint32 *op_flags)
{
npy_intp i, nin = ufunc->nin, nout = ufunc->nout;
npy_intp nop = nin + nout;
NpyIter *iter;
char *baseptrs[NPY_MAXARGS];

NpyIter_IterNextFunc *iternext;
char **dataptr;
npy_intp *stride;
npy_intp *count_ptr;
int needs_api;

PyArrayObject **op_it;
npy_uint32 iter_flags;

NPY_BEGIN_THREADS_DEF;
int nin = ufunc->nin, nout = ufunc->nout;
int nop = nin + nout;

iter_flags = ufunc->iter_flags |
npy_uint32 iter_flags = ufunc->iter_flags |
NPY_ITER_EXTERNAL_LOOP |
NPY_ITER_REFS_OK |
NPY_ITER_ZEROSIZE_OK |
Expand All @@ -1288,16 +1275,17 @@ iterator_loop(PyUFuncObject *ufunc,
NPY_ITER_DELAY_BUFALLOC |
NPY_ITER_COPY_IF_OVERLAP;

/* Call the __array_prepare__ functions for already existing output arrays.
/*
* Call the __array_prepare__ functions for already existing output arrays.
* Do this before creating the iterator, as the iterator may UPDATEIFCOPY
* some of them.
*/
for (i = 0; i < nout; ++i) {
for (int i = 0; i < nout; i++) {
if (op[nin+i] == NULL) {
continue;
}
if (prepare_ufunc_output(ufunc, &op[nin+i],
arr_prep[i], full_args, i) < 0) {
arr_prep[i], full_args, i) < 0) {
return -1;
}
}
Expand All @@ -1307,7 +1295,7 @@ iterator_loop(PyUFuncObject *ufunc,
* were already checked, we use the casting rule 'unsafe' which
* is faster to calculate.
*/
iter = NpyIter_AdvancedNew(nop, op,
NpyIter *iter = NpyIter_AdvancedNew(nop, op,
iter_flags,
order, NPY_UNSAFE_CASTING,
op_flags, dtype,
Expand All @@ -1316,17 +1304,20 @@ iterator_loop(PyUFuncObject *ufunc,
return -1;
}

/* Copy any allocated outputs */
op_it = NpyIter_GetOperandArray(iter);
for (i = 0; i < nout; ++i) {
if (op[nin+i] == NULL) {
op[nin+i] = op_it[nin+i];
Py_INCREF(op[nin+i]);
NPY_UF_DBG_PRINT("Made iterator\n");

/* Call the __array_prepare__ functions where necessary */
PyArrayObject **op_it = NpyIter_GetOperandArray(iter);
char *baseptrs[NPY_MAXARGS];

for (int i = 0; i < nout; ++i) {
if (op[nin + i] == NULL) {
op[nin + i] = op_it[nin + i];
Py_INCREF(op[nin + i]);

/* Call the __array_prepare__ functions for the new array */
if (prepare_ufunc_output(ufunc, &op[nin+i],
arr_prep[i], full_args, i) < 0) {
NpyIter_Deallocate(iter);
if (prepare_ufunc_output(ufunc,
&op[nin + i], arr_prep[i], full_args, i) < 0) {
return -1;
}

Expand All @@ -1340,45 +1331,59 @@ iterator_loop(PyUFuncObject *ufunc,
* with other operands --- the op[nin+i] array passed to it is newly
* allocated and doesn't have any overlap.
*/
baseptrs[nin+i] = PyArray_BYTES(op[nin+i]);
baseptrs[nin + i] = PyArray_BYTES(op[nin + i]);
}
else {
baseptrs[nin+i] = PyArray_BYTES(op_it[nin+i]);
baseptrs[nin + i] = PyArray_BYTES(op_it[nin + i]);
}
}

/* Only do the loop if the iteration size is non-zero */
if (NpyIter_GetIterSize(iter) != 0) {
/* Reset the iterator with the base pointers from possible __array_prepare__ */
for (i = 0; i < nin; ++i) {
baseptrs[i] = PyArray_BYTES(op_it[i]);
}
if (NpyIter_ResetBasePointers(iter, baseptrs, NULL) != NPY_SUCCEED) {
NpyIter_Deallocate(iter);
npy_intp full_size = NpyIter_GetIterSize(iter);
if (full_size == 0) {
if (!NpyIter_Deallocate(iter)) {
return -1;
}
return 0;
}

/* Get the variables needed for the loop */
iternext = NpyIter_GetIterNext(iter, NULL);
if (iternext == NULL) {
NpyIter_Deallocate(iter);
return -1;
}
dataptr = NpyIter_GetDataPtrArray(iter);
stride = NpyIter_GetInnerStrideArray(iter);
count_ptr = NpyIter_GetInnerLoopSizePtr(iter);
needs_api = NpyIter_IterationNeedsAPI(iter);
/*
* Reset the iterator with the base pointers possibly modified by
* `__array_prepare__`.
*/
for (int i = 0; i < nin; i++) {
baseptrs[i] = PyArray_BYTES(op_it[i]);
}
if (NpyIter_ResetBasePointers(iter, baseptrs, NULL) != NPY_SUCCEED) {
NpyIter_Deallocate(iter);
return -1;
}

NPY_BEGIN_THREADS_NDITER(iter);
/* Get the variables needed for the loop */
NpyIter_IterNextFunc *iternext = NpyIter_GetIterNext(iter, NULL);
if (iternext == NULL) {
NpyIter_Deallocate(iter);
return -1;
}
char **dataptr = NpyIter_GetDataPtrArray(iter);
npy_intp *strides = NpyIter_GetInnerStrideArray(iter);
npy_intp *countptr = NpyIter_GetInnerLoopSizePtr(iter);
int needs_api = NpyIter_IterationNeedsAPI(iter);

/* Execute the loop */
do {
NPY_UF_DBG_PRINT1("iterator loop count %d\n", (int)*count_ptr);
innerloop(dataptr, count_ptr, stride, innerloopdata);
} while (!(needs_api && PyErr_Occurred()) && iternext(iter));
NPY_BEGIN_THREADS_DEF;

NPY_END_THREADS;
if (!needs_api) {
NPY_BEGIN_THREADS_THRESHOLDED(full_size);
}

NPY_UF_DBG_PRINT("Actual inner loop:\n");
/* Execute the loop */
do {
NPY_UF_DBG_PRINT1("iterator loop count %d\n", (int)*count_ptr);
innerloop(dataptr, countptr, strides, innerloopdata);
} while (!(needs_api && PyErr_Occurred()) && iternext(iter));

NPY_END_THREADS;

/*
* Currently `innerloop` may leave an error set, in this case
* NpyIter_Deallocate will always return an error as well.
Expand Down

0 comments on commit 6e144f0

Please sign in to comment.