Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add overflow handling for negative integers scalar multiplication #21793

Merged
merged 10 commits into from
Jul 17, 2022
2 changes: 1 addition & 1 deletion numpy/core/src/common/npy_hashtable.c
Expand Up @@ -146,7 +146,7 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
}

npy_intp alloc_size;
if (npy_mul_with_overflow_intp(&alloc_size, new_size, tb->key_len + 1)) {
if (npy_mul_sizes_with_overflow(&alloc_size, new_size, tb->key_len + 1)) {
return -1;
}
tb->buckets = PyMem_Calloc(alloc_size, sizeof(PyObject *));
Expand Down
42 changes: 36 additions & 6 deletions numpy/core/src/common/templ_common.h.src
Expand Up @@ -7,13 +7,15 @@

/**begin repeat
* #name = int, uint, long, ulong,
* longlong, ulonglong, intp#
* longlong, ulonglong#
* #type = npy_int, npy_uint, npy_long, npy_ulong,
* npy_longlong, npy_ulonglong, npy_intp#
* npy_longlong, npy_ulonglong#
* #MAX = NPY_MAX_INT, NPY_MAX_UINT, NPY_MAX_LONG, NPY_MAX_ULONG,
* NPY_MAX_LONGLONG, NPY_MAX_ULONGLONG, NPY_MAX_INTP#
* NPY_MAX_LONGLONG, NPY_MAX_ULONGLONG#
*
* #neg = (1,0)*3#
* #abs_func = abs*2, labs*2, llabs*2#
*/

/*
* writes result of a * b into r
* returns 1 if a * b overflowed else returns 0
Expand All @@ -34,13 +36,41 @@ npy_mul_with_overflow_@name@(@type@ * r, @type@ a, @type@ b)
/*
* avoid expensive division on common no overflow case
*/
if (NPY_UNLIKELY((a | b) >= half_sz) &&
a != 0 && b > @MAX@ / a) {
if ((NPY_UNLIKELY((a | b) >= half_sz) || (a | b) < 0) &&
a != 0 &&
#if @neg@
@abs_func@(b) > @abs_func@(@MAX@ / a)
#else
b > @MAX@ / a
#endif
) {
return 1;
}
return 0;
#endif
}
/**end repeat**/

static NPY_INLINE int
npy_mul_sizes_with_overflow (npy_intp * r, npy_intp a, npy_intp b)
{
#ifdef HAVE___BUILTIN_MUL_OVERFLOW
return __builtin_mul_overflow(a, b, r);
#else

assert a >= 0 && b >= 0, "this function only supports non-negative numbers"
seberg marked this conversation as resolved.
Show resolved Hide resolved
const npy_intp half_sz = ((npy_intp)1 << ((sizeof(a) * 8 - 1 ) / 2));

*r = a * b;
/*
* avoid expensive division on common no overflow case
*/
if (NPY_UNLIKELY((a | b) >= half_sz)
&& a != 0 && b > NPY_MAX_INTP / a) {
return 1;
}
return 0;
#endif
}

#endif
6 changes: 3 additions & 3 deletions numpy/core/src/multiarray/compiled_base.c
Expand Up @@ -9,7 +9,7 @@
#include "numpy/npy_3kcompat.h"
#include "numpy/npy_math.h"
#include "npy_config.h"
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "lowlevel_strided_loops.h" /* for npy_bswap8 */
#include "alloc.h"
#include "ctors.h"
Expand Down Expand Up @@ -1069,7 +1069,7 @@ arr_ravel_multi_index(PyObject *self, PyObject *args, PyObject *kwds)
s = 1;
for (i = dimensions.len-1; i >= 0; --i) {
ravel_strides[i] = s;
if (npy_mul_with_overflow_intp(&s, s, dimensions.ptr[i])) {
if (npy_mul_sizes_with_overflow(&s, s, dimensions.ptr[i])) {
PyErr_SetString(PyExc_ValueError,
"invalid dims: array size defined by dims is larger "
"than the maximum possible size.");
Expand All @@ -1081,7 +1081,7 @@ arr_ravel_multi_index(PyObject *self, PyObject *args, PyObject *kwds)
s = 1;
for (i = 0; i < dimensions.len; ++i) {
ravel_strides[i] = s;
if (npy_mul_with_overflow_intp(&s, s, dimensions.ptr[i])) {
if (npy_mul_sizes_with_overflow(&s, s, dimensions.ptr[i])) {
PyErr_SetString(PyExc_ValueError,
"invalid dims: array size defined by dims is larger "
"than the maximum possible size.");
Expand Down
6 changes: 3 additions & 3 deletions numpy/core/src/multiarray/ctors.c
Expand Up @@ -27,7 +27,7 @@
#include "datetime_strings.h"
#include "array_assign.h"
#include "mapping.h" /* for array_item_asarray */
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "alloc.h"
#include <assert.h>

Expand Down Expand Up @@ -746,7 +746,7 @@ PyArray_NewFromDescr_int(
* Care needs to be taken to avoid integer overflow when multiplying
* the dimensions together to get the total size of the array.
*/
if (npy_mul_with_overflow_intp(&nbytes, nbytes, fa->dimensions[i])) {
if (npy_mul_sizes_with_overflow(&nbytes, nbytes, fa->dimensions[i])) {
PyErr_SetString(PyExc_ValueError,
"array is too big; `arr.size * arr.dtype.itemsize` "
"is larger than the maximum possible size.");
Expand Down Expand Up @@ -3956,7 +3956,7 @@ PyArray_FromIter(PyObject *obj, PyArray_Descr *dtype, npy_intp count)
be suitable to reuse here.
*/
elcount = (i >> 1) + (i < 4 ? 4 : 2) + i;
if (!npy_mul_with_overflow_intp(&nbytes, elcount, elsize)) {
if (!npy_mul_sizes_with_overflow(&nbytes, elcount, elsize)) {
/* The handler is always valid */
new_data = PyDataMem_UserRENEW(
PyArray_BYTES(ret), nbytes, PyArray_HANDLER(ret));
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/descriptor.c
Expand Up @@ -15,7 +15,7 @@

#include "_datetime.h"
#include "common.h"
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "descriptor.h"
#include "alloc.h"
#include "assert.h"
Expand Down
6 changes: 3 additions & 3 deletions numpy/core/src/multiarray/methods.c
Expand Up @@ -17,7 +17,7 @@
#include "ufunc_override.h"
#include "array_coercion.h"
#include "common.h"
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "ctors.h"
#include "calculation.h"
#include "convert_datatype.h"
Expand Down Expand Up @@ -2053,13 +2053,13 @@ array_setstate(PyArrayObject *self, PyObject *args)
if (dimensions[i] == 0) {
empty = NPY_TRUE;
}
overflowed = npy_mul_with_overflow_intp(
overflowed = npy_mul_sizes_with_overflow(
&nbytes, nbytes, dimensions[i]);
if (overflowed) {
return PyErr_NoMemory();
}
}
overflowed = npy_mul_with_overflow_intp(
overflowed = npy_mul_sizes_with_overflow(
&nbytes, nbytes, PyArray_DESCR(self)->elsize);
if (overflowed) {
return PyErr_NoMemory();
Expand Down
4 changes: 2 additions & 2 deletions numpy/core/src/multiarray/multiarraymodule.c
Expand Up @@ -62,7 +62,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
#include "multiarraymodule.h"
#include "cblasfuncs.h"
#include "vdot.h"
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "compiled_base.h"
#include "mem_overlap.h"
#include "typeinfo.h"
Expand Down Expand Up @@ -194,7 +194,7 @@ PyArray_OverflowMultiplyList(npy_intp const *l1, int n)
if (dim == 0) {
return 0;
}
if (npy_mul_with_overflow_intp(&prod, prod, dim)) {
if (npy_mul_sizes_with_overflow(&prod, prod, dim)) {
return -1;
}
}
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/nditer_api.c
Expand Up @@ -132,7 +132,7 @@ NpyIter_RemoveAxis(NpyIter *iter, int axis)
NIT_ITERSIZE(iter) = 1;
axisdata = NIT_AXISDATA(iter);
for (idim = 0; idim < ndim-1; ++idim) {
if (npy_mul_with_overflow_intp(&NIT_ITERSIZE(iter),
if (npy_mul_sizes_with_overflow(&NIT_ITERSIZE(iter),
NIT_ITERSIZE(iter), NAD_SHAPE(axisdata))) {
NIT_ITERSIZE(iter) = -1;
break;
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/nditer_constr.c
Expand Up @@ -1721,7 +1721,7 @@ npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itf
/* Now fill in the ITERSIZE member */
NIT_ITERSIZE(iter) = 1;
for (idim = 0; idim < ndim; ++idim) {
if (npy_mul_with_overflow_intp(&NIT_ITERSIZE(iter),
if (npy_mul_sizes_with_overflow(&NIT_ITERSIZE(iter),
NIT_ITERSIZE(iter), broadcast_shape[idim])) {
if ((itflags & NPY_ITFLAG_HASMULTIINDEX) &&
!(itflags & NPY_ITFLAG_HASINDEX) &&
Expand Down
8 changes: 4 additions & 4 deletions numpy/core/src/multiarray/shape.c
Expand Up @@ -19,7 +19,7 @@
#include "shape.h"

#include "multiarraymodule.h" /* for interned strings */
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "common.h" /* for convert_shape_to_string */
#include "alloc.h"

Expand Down Expand Up @@ -71,15 +71,15 @@ PyArray_Resize(PyArrayObject *self, PyArray_Dims *newshape, int refcheck,
"negative dimensions not allowed");
return NULL;
}
if (npy_mul_with_overflow_intp(&newsize, newsize, new_dimensions[k])) {
if (npy_mul_sizes_with_overflow(&newsize, newsize, new_dimensions[k])) {
return PyErr_NoMemory();
}
}

/* Convert to number of bytes. The new count might overflow */
elsize = PyArray_DESCR(self)->elsize;
oldnbytes = oldsize * elsize;
if (npy_mul_with_overflow_intp(&newnbytes, newsize, elsize)) {
if (npy_mul_sizes_with_overflow(&newnbytes, newsize, elsize)) {
return PyErr_NoMemory();
}

Expand Down Expand Up @@ -498,7 +498,7 @@ _fix_unknown_dimension(PyArray_Dims *newshape, PyArrayObject *arr)
return -1;
}
}
else if (npy_mul_with_overflow_intp(&s_known, s_known,
else if (npy_mul_sizes_with_overflow(&s_known, s_known,
dimensions[i])) {
raise_reshape_size_mismatch(newshape, arr);
return -1;
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/textreading/growth.c
Expand Up @@ -39,7 +39,7 @@ grow_size_and_multiply(npy_intp *size, npy_intp min_grow, npy_intp itemsize) {
}
*size = (npy_intp)new_size;
npy_intp alloc_size;
if (npy_mul_with_overflow_intp(&alloc_size, (npy_intp)new_size, itemsize)) {
if (npy_mul_sizes_with_overflow(&alloc_size, (npy_intp)new_size, itemsize)) {
return -1;
}
return alloc_size;
Expand Down
3 changes: 1 addition & 2 deletions numpy/core/tests/test_scalarmath.py
Expand Up @@ -902,8 +902,7 @@ def test_scalar_integer_operation_overflow(dtype, operation):
@pytest.mark.parametrize("operation", [
lambda min, neg_1: -min,
lambda min, neg_1: abs(min),
pytest.param(lambda min, neg_1: min * neg_1,
marks=pytest.mark.xfail(reason="broken on some platforms")),
lambda min, neg_1: min * neg_1,
pytest.param(lambda min, neg_1: min // neg_1,
marks=pytest.mark.skip(reason="broken on some platforms"))],
ids=["neg", "abs", "*", "//"])
Expand Down