Skip to content

Commit

Permalink
Merge pull request #20798 from charris/backport-20643
Browse files Browse the repository at this point in the history
PERF: Optimize array check for bounded 0,1 values
  • Loading branch information
charris committed Jan 11, 2022
2 parents 8f87be6 + b059f22 commit 115cf09
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions numpy/random/_common.pyx
Expand Up @@ -5,6 +5,7 @@ from cpython cimport PyFloat_AsDouble
import sys
import numpy as np
cimport numpy as np
cimport numpy.math as npmath

from libc.stdint cimport uintptr_t

Expand Down Expand Up @@ -343,6 +344,24 @@ cdef object float_fill_from_double(void *func, bitgen_t *state, object size, obj
out_array_data[i] = <float>random_func(state)
return out_array

cdef int _check_array_cons_bounded_0_1(np.ndarray val, object name) except -1:
cdef double *val_data
cdef np.npy_intp i
cdef bint err = 0

if not np.PyArray_ISONESEGMENT(val) or np.PyArray_TYPE(val) != np.NPY_DOUBLE:
# slow path for non-contiguous arrays or any non-double dtypes
err = not np.all(np.greater_equal(val, 0)) or not np.all(np.less_equal(val, 1))
else:
val_data = <double *>np.PyArray_DATA(val)
for i in range(np.PyArray_SIZE(val)):
err = (not (val_data[i] >= 0)) or (not val_data[i] <= 1)
if err:
break
if err:
raise ValueError(f"{name} < 0, {name} > 1 or {name} contains NaNs")

return 0

cdef int check_array_constraint(np.ndarray val, object name, constraint_type cons) except -1:
if cons == CONS_NON_NEGATIVE:
Expand All @@ -354,9 +373,7 @@ cdef int check_array_constraint(np.ndarray val, object name, constraint_type con
elif np.any(np.less_equal(val, 0)):
raise ValueError(name + " <= 0")
elif cons == CONS_BOUNDED_0_1:
if not np.all(np.greater_equal(val, 0)) or \
not np.all(np.less_equal(val, 1)):
raise ValueError("{0} < 0, {0} > 1 or {0} contains NaNs".format(name))
return _check_array_cons_bounded_0_1(val, name)
elif cons == CONS_BOUNDED_GT_0_1:
if not np.all(np.greater(val, 0)) or not np.all(np.less_equal(val, 1)):
raise ValueError("{0} <= 0, {0} > 1 or {0} contains NaNs".format(name))
Expand Down

0 comments on commit 115cf09

Please sign in to comment.