Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PERF: Optimize array check for bounded 0,1 values #20798

Merged
merged 1 commit into from Jan 11, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 20 additions & 3 deletions numpy/random/_common.pyx
Expand Up @@ -5,6 +5,7 @@ from cpython cimport PyFloat_AsDouble
import sys
import numpy as np
cimport numpy as np
cimport numpy.math as npmath

from libc.stdint cimport uintptr_t

Expand Down Expand Up @@ -343,6 +344,24 @@ cdef object float_fill_from_double(void *func, bitgen_t *state, object size, obj
out_array_data[i] = <float>random_func(state)
return out_array

cdef int _check_array_cons_bounded_0_1(np.ndarray val, object name) except -1:
cdef double *val_data
cdef np.npy_intp i
cdef bint err = 0

if not np.PyArray_ISONESEGMENT(val) or np.PyArray_TYPE(val) != np.NPY_DOUBLE:
# slow path for non-contiguous arrays or any non-double dtypes
err = not np.all(np.greater_equal(val, 0)) or not np.all(np.less_equal(val, 1))
else:
val_data = <double *>np.PyArray_DATA(val)
for i in range(np.PyArray_SIZE(val)):
err = (not (val_data[i] >= 0)) or (not val_data[i] <= 1)
if err:
break
if err:
raise ValueError(f"{name} < 0, {name} > 1 or {name} contains NaNs")

return 0

cdef int check_array_constraint(np.ndarray val, object name, constraint_type cons) except -1:
if cons == CONS_NON_NEGATIVE:
Expand All @@ -354,9 +373,7 @@ cdef int check_array_constraint(np.ndarray val, object name, constraint_type con
elif np.any(np.less_equal(val, 0)):
raise ValueError(name + " <= 0")
elif cons == CONS_BOUNDED_0_1:
if not np.all(np.greater_equal(val, 0)) or \
not np.all(np.less_equal(val, 1)):
raise ValueError("{0} < 0, {0} > 1 or {0} contains NaNs".format(name))
return _check_array_cons_bounded_0_1(val, name)
elif cons == CONS_BOUNDED_GT_0_1:
if not np.all(np.greater(val, 0)) or not np.all(np.less_equal(val, 1)):
raise ValueError("{0} <= 0, {0} > 1 or {0} contains NaNs".format(name))
Expand Down