Skip to content

Commit

Permalink
ENH,MAINT: Add overflow handling for negative integers scalar multipl…
Browse files Browse the repository at this point in the history
…ication (#21793)

Rename `npy_mul_with_overflow_intp` to `npy_mul_sizes_with_overflow`, as it only supports non-negative numbers.  Then introduce new versions for all integer types (not intp) that also handle negative values, so they can be used for integer scalar math with overflow detection.
(The new versions are OK to use everywhere; it is just that for sizes we know the values will normally be non-negative.)

Related to #21506

Follow-up to #21648
  • Loading branch information
Micky774 committed Jul 17, 2022
1 parent 6b8d55e commit 4156ae2
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 28 deletions.
2 changes: 1 addition & 1 deletion numpy/core/src/common/npy_hashtable.c
Expand Up @@ -146,7 +146,7 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
}

npy_intp alloc_size;
if (npy_mul_with_overflow_intp(&alloc_size, new_size, tb->key_len + 1)) {
if (npy_mul_sizes_with_overflow(&alloc_size, new_size, tb->key_len + 1)) {
return -1;
}
tb->buckets = PyMem_Calloc(alloc_size, sizeof(PyObject *));
Expand Down
42 changes: 36 additions & 6 deletions numpy/core/src/common/templ_common.h.src
Expand Up @@ -7,13 +7,15 @@

/**begin repeat
* #name = int, uint, long, ulong,
* longlong, ulonglong, intp#
* longlong, ulonglong#
* #type = npy_int, npy_uint, npy_long, npy_ulong,
* npy_longlong, npy_ulonglong, npy_intp#
* npy_longlong, npy_ulonglong#
* #MAX = NPY_MAX_INT, NPY_MAX_UINT, NPY_MAX_LONG, NPY_MAX_ULONG,
* NPY_MAX_LONGLONG, NPY_MAX_ULONGLONG, NPY_MAX_INTP#
* NPY_MAX_LONGLONG, NPY_MAX_ULONGLONG#
*
* #neg = (1,0)*3#
* #abs_func = abs*2, labs*2, llabs*2#
*/

/*
* writes result of a * b into r
* returns 1 if a * b overflowed else returns 0
Expand All @@ -34,13 +36,41 @@ npy_mul_with_overflow_@name@(@type@ * r, @type@ a, @type@ b)
/*
* avoid expensive division on common no overflow case
*/
if (NPY_UNLIKELY((a | b) >= half_sz) &&
a != 0 && b > @MAX@ / a) {
if ((NPY_UNLIKELY((a | b) >= half_sz) || (a | b) < 0) &&
a != 0 &&
#if @neg@
@abs_func@(b) > @abs_func@(@MAX@ / a)
#else
b > @MAX@ / a
#endif
) {
return 1;
}
return 0;
#endif
}
/**end repeat**/

/*
 * Multiply two sizes with overflow detection.
 *
 * Writes the result of a * b into r.
 * Returns 1 if a * b overflowed, else returns 0.
 *
 * Both a and b are expected to be non-negative; the fallback (non-builtin)
 * implementation asserts this in debug builds and its overflow test is
 * only written for non-negative operands.
 */
static NPY_INLINE int
npy_mul_sizes_with_overflow (npy_intp * r, npy_intp a, npy_intp b)
{
#ifdef HAVE___BUILTIN_MUL_OVERFLOW
    return __builtin_mul_overflow(a, b, r);
#else
    /*
     * assert() is a one-argument macro: the message must be joined to the
     * condition with `&&` (a string literal is always true), not passed as
     * a second argument, which would fail to compile without NDEBUG.
     */
    assert(a >= 0 && b >= 0 &&
           "this function only supports non-negative numbers");
    /*
     * If both operands are below 2**((bits - 1) / 2) the product cannot
     * exceed the maximum value, so the division below can be skipped.
     */
    const npy_intp half_sz = ((npy_intp)1 << ((sizeof(a) * 8 - 1 ) / 2));

    *r = a * b;
    /*
     * avoid expensive division on common no overflow case
     */
    if (NPY_UNLIKELY((a | b) >= half_sz)
            && a != 0 && b > NPY_MAX_INTP / a) {
        return 1;
    }
    return 0;
#endif
}

#endif
6 changes: 3 additions & 3 deletions numpy/core/src/multiarray/compiled_base.c
Expand Up @@ -9,7 +9,7 @@
#include "numpy/npy_3kcompat.h"
#include "numpy/npy_math.h"
#include "npy_config.h"
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "lowlevel_strided_loops.h" /* for npy_bswap8 */
#include "alloc.h"
#include "ctors.h"
Expand Down Expand Up @@ -1069,7 +1069,7 @@ arr_ravel_multi_index(PyObject *self, PyObject *args, PyObject *kwds)
s = 1;
for (i = dimensions.len-1; i >= 0; --i) {
ravel_strides[i] = s;
if (npy_mul_with_overflow_intp(&s, s, dimensions.ptr[i])) {
if (npy_mul_sizes_with_overflow(&s, s, dimensions.ptr[i])) {
PyErr_SetString(PyExc_ValueError,
"invalid dims: array size defined by dims is larger "
"than the maximum possible size.");
Expand All @@ -1081,7 +1081,7 @@ arr_ravel_multi_index(PyObject *self, PyObject *args, PyObject *kwds)
s = 1;
for (i = 0; i < dimensions.len; ++i) {
ravel_strides[i] = s;
if (npy_mul_with_overflow_intp(&s, s, dimensions.ptr[i])) {
if (npy_mul_sizes_with_overflow(&s, s, dimensions.ptr[i])) {
PyErr_SetString(PyExc_ValueError,
"invalid dims: array size defined by dims is larger "
"than the maximum possible size.");
Expand Down
6 changes: 3 additions & 3 deletions numpy/core/src/multiarray/ctors.c
Expand Up @@ -27,7 +27,7 @@
#include "datetime_strings.h"
#include "array_assign.h"
#include "mapping.h" /* for array_item_asarray */
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "alloc.h"
#include <assert.h>

Expand Down Expand Up @@ -746,7 +746,7 @@ PyArray_NewFromDescr_int(
* Care needs to be taken to avoid integer overflow when multiplying
* the dimensions together to get the total size of the array.
*/
if (npy_mul_with_overflow_intp(&nbytes, nbytes, fa->dimensions[i])) {
if (npy_mul_sizes_with_overflow(&nbytes, nbytes, fa->dimensions[i])) {
PyErr_SetString(PyExc_ValueError,
"array is too big; `arr.size * arr.dtype.itemsize` "
"is larger than the maximum possible size.");
Expand Down Expand Up @@ -3956,7 +3956,7 @@ PyArray_FromIter(PyObject *obj, PyArray_Descr *dtype, npy_intp count)
be suitable to reuse here.
*/
elcount = (i >> 1) + (i < 4 ? 4 : 2) + i;
if (!npy_mul_with_overflow_intp(&nbytes, elcount, elsize)) {
if (!npy_mul_sizes_with_overflow(&nbytes, elcount, elsize)) {
/* The handler is always valid */
new_data = PyDataMem_UserRENEW(
PyArray_BYTES(ret), nbytes, PyArray_HANDLER(ret));
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/descriptor.c
Expand Up @@ -15,7 +15,7 @@

#include "_datetime.h"
#include "common.h"
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "descriptor.h"
#include "alloc.h"
#include "assert.h"
Expand Down
6 changes: 3 additions & 3 deletions numpy/core/src/multiarray/methods.c
Expand Up @@ -17,7 +17,7 @@
#include "ufunc_override.h"
#include "array_coercion.h"
#include "common.h"
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "ctors.h"
#include "calculation.h"
#include "convert_datatype.h"
Expand Down Expand Up @@ -2053,13 +2053,13 @@ array_setstate(PyArrayObject *self, PyObject *args)
if (dimensions[i] == 0) {
empty = NPY_TRUE;
}
overflowed = npy_mul_with_overflow_intp(
overflowed = npy_mul_sizes_with_overflow(
&nbytes, nbytes, dimensions[i]);
if (overflowed) {
return PyErr_NoMemory();
}
}
overflowed = npy_mul_with_overflow_intp(
overflowed = npy_mul_sizes_with_overflow(
&nbytes, nbytes, PyArray_DESCR(self)->elsize);
if (overflowed) {
return PyErr_NoMemory();
Expand Down
4 changes: 2 additions & 2 deletions numpy/core/src/multiarray/multiarraymodule.c
Expand Up @@ -62,7 +62,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
#include "multiarraymodule.h"
#include "cblasfuncs.h"
#include "vdot.h"
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "compiled_base.h"
#include "mem_overlap.h"
#include "typeinfo.h"
Expand Down Expand Up @@ -194,7 +194,7 @@ PyArray_OverflowMultiplyList(npy_intp const *l1, int n)
if (dim == 0) {
return 0;
}
if (npy_mul_with_overflow_intp(&prod, prod, dim)) {
if (npy_mul_sizes_with_overflow(&prod, prod, dim)) {
return -1;
}
}
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/nditer_api.c
Expand Up @@ -132,7 +132,7 @@ NpyIter_RemoveAxis(NpyIter *iter, int axis)
NIT_ITERSIZE(iter) = 1;
axisdata = NIT_AXISDATA(iter);
for (idim = 0; idim < ndim-1; ++idim) {
if (npy_mul_with_overflow_intp(&NIT_ITERSIZE(iter),
if (npy_mul_sizes_with_overflow(&NIT_ITERSIZE(iter),
NIT_ITERSIZE(iter), NAD_SHAPE(axisdata))) {
NIT_ITERSIZE(iter) = -1;
break;
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/nditer_constr.c
Expand Up @@ -1721,7 +1721,7 @@ npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itf
/* Now fill in the ITERSIZE member */
NIT_ITERSIZE(iter) = 1;
for (idim = 0; idim < ndim; ++idim) {
if (npy_mul_with_overflow_intp(&NIT_ITERSIZE(iter),
if (npy_mul_sizes_with_overflow(&NIT_ITERSIZE(iter),
NIT_ITERSIZE(iter), broadcast_shape[idim])) {
if ((itflags & NPY_ITFLAG_HASMULTIINDEX) &&
!(itflags & NPY_ITFLAG_HASINDEX) &&
Expand Down
8 changes: 4 additions & 4 deletions numpy/core/src/multiarray/shape.c
Expand Up @@ -19,7 +19,7 @@
#include "shape.h"

#include "multiarraymodule.h" /* for interned strings */
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "common.h" /* for convert_shape_to_string */
#include "alloc.h"

Expand Down Expand Up @@ -71,15 +71,15 @@ PyArray_Resize(PyArrayObject *self, PyArray_Dims *newshape, int refcheck,
"negative dimensions not allowed");
return NULL;
}
if (npy_mul_with_overflow_intp(&newsize, newsize, new_dimensions[k])) {
if (npy_mul_sizes_with_overflow(&newsize, newsize, new_dimensions[k])) {
return PyErr_NoMemory();
}
}

/* Convert to number of bytes. The new count might overflow */
elsize = PyArray_DESCR(self)->elsize;
oldnbytes = oldsize * elsize;
if (npy_mul_with_overflow_intp(&newnbytes, newsize, elsize)) {
if (npy_mul_sizes_with_overflow(&newnbytes, newsize, elsize)) {
return PyErr_NoMemory();
}

Expand Down Expand Up @@ -498,7 +498,7 @@ _fix_unknown_dimension(PyArray_Dims *newshape, PyArrayObject *arr)
return -1;
}
}
else if (npy_mul_with_overflow_intp(&s_known, s_known,
else if (npy_mul_sizes_with_overflow(&s_known, s_known,
dimensions[i])) {
raise_reshape_size_mismatch(newshape, arr);
return -1;
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/textreading/growth.c
Expand Up @@ -39,7 +39,7 @@ grow_size_and_multiply(npy_intp *size, npy_intp min_grow, npy_intp itemsize) {
}
*size = (npy_intp)new_size;
npy_intp alloc_size;
if (npy_mul_with_overflow_intp(&alloc_size, (npy_intp)new_size, itemsize)) {
if (npy_mul_sizes_with_overflow(&alloc_size, (npy_intp)new_size, itemsize)) {
return -1;
}
return alloc_size;
Expand Down
3 changes: 1 addition & 2 deletions numpy/core/tests/test_scalarmath.py
Expand Up @@ -902,8 +902,7 @@ def test_scalar_integer_operation_overflow(dtype, operation):
@pytest.mark.parametrize("operation", [
lambda min, neg_1: -min,
lambda min, neg_1: abs(min),
pytest.param(lambda min, neg_1: min * neg_1,
marks=pytest.mark.xfail(reason="broken on some platforms")),
lambda min, neg_1: min * neg_1,
pytest.param(lambda min, neg_1: min // neg_1,
marks=pytest.mark.skip(reason="broken on some platforms"))],
ids=["neg", "abs", "*", "//"])
Expand Down

0 comments on commit 4156ae2

Please sign in to comment.