Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve calculation of the scale parameter for the uniform float distribution. #1301

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
215 changes: 203 additions & 12 deletions src/distributions/uniform.rs
Expand Up @@ -826,12 +826,109 @@ pub struct UniformFloat<X> {
scale: X,
}

/// Helper trait for computing a numerically careful sum of a slice of
/// (scalar or SIMD) floating point values.
trait Summable<T> {
    /// Sum the values using a compensated (error-tracking) summation
    /// algorithm to reduce floating point rounding error.
    fn compensated_sum(&self) -> T;
}

/// Helper trait used by `UniformFloat` to derive the `scale` parameter
/// from the sampling bounds.
trait ScaleComputable<T> {
    /// Compute the largest `scale` such that `scale * max_rand + low < high`,
    /// where `max_rand` is the largest value the underlying uniform
    /// [0, 1) generator can produce.
    fn compute_scale(low: T, high: T) -> T;
}
dhardy marked this conversation as resolved.
Show resolved Hide resolved

macro_rules! uniform_float_impl {
($ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => {
impl SampleUniform for $ty {
    // `UniformFloat` is the distribution back-end for this float type.
    type Sampler = UniformFloat<$ty>;
}

impl Summable<$ty> for &[$ty] {
    /// Sum the slice with Kahan compensated summation, which tracks the
    /// low-order bits lost by each floating point addition and feeds them
    /// back into subsequent additions.
    fn compensated_sum(&self) -> $ty {
        // Kahan compensated sum
        let mut sum = <$ty>::splat(0.0);
        // `c` is the running compensation: the error accumulated so far.
        let mut c = <$ty>::splat(0.0);
        for val in *self {
            let y = val - c;
            let t = sum + y;
            // `(t - sum) - y` algebraically cancels to 0, but in floating
            // point it recovers exactly the part of `y` lost in `sum + y`.
            c = (t - sum) - y;
            sum = t;
        }
        sum
    }
}

impl ScaleComputable<$ty> for $ty {
    /// Compute `scale` so that `scale * max_rand + low` is the largest
    /// representable value strictly below `high`. For SIMD types the
    /// computation is performed lane-wise.
    fn compute_scale(low: $ty, high: $ty) -> $ty {
        let eps = <$ty>::splat($f_scalar::EPSILON);

        // `max_rand` is 1.0 - eps. This is actually the second largest
        // float less than 1.0, because the spacing of the floats in the
        // interval [0.5, 1.0) is `eps/2`.
        let max_rand = <$ty>::splat(
            (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
        );

        // `delta_high` is half the distance from `high` to the largest
        // float that is less than `high`. If `high` is subnormal or 0,
        // then `delta_high` will be 0. Why this is needed is explained
        // below.
        let delta_high = <$ty>::splat(0.5) * (high - high.utils_next_down());

        // We want `scale * max_rand + low < high`. Let `high_1` be the
        // (hypothetical) float that is the midpoint between `high` and
        // the largest float less than `high`. The midpoint is used
        // because any float calculation that would result in a value in
        // `(high_1, high)` would be rounded to `high`. The ideal
        // condition for the upper bound of `scale` is then
        //     `scale * max_rand + low = high_1`
        // or
        //     `scale = (high_1 - low) / max_rand`
        //
        // Write `high_1 = high - delta_high`, `max_rand = 1 - eps`,
        // and approximate `1/(1 - eps)` as `(1 + eps)`. Then we have
        //
        //     scale = (high - delta_high - low)*(1 + eps)
        //           = high - low + eps*high - eps*low - delta_high
        //
        // (The extremely small term `-delta_high*eps` has been ignored.)
        // The following uses Kahan's compensated summation to compute `scale`
        // from those terms.
        let terms: &[$ty] = &[high, -low, eps * high, -eps * low, -delta_high];
        let mut scale = terms.compensated_sum();

        // Empirical tests show that `scale` is generally within 1 or 2 ULPs
        // of the "ideal" scale. Next we adjust `scale`, if necessary, to
        // the ideal value.

        // Check that `scale * max_rand + low` is less than `high`. If it is
        // not, repeatedly adjust `scale` down by one ULP until the condition
        // is satisfied. Generally this requires 0 or 1 adjustments to `scale`.
        // (The original `too_big_mask` is saved so we can use it again below.)
        let too_big_mask = (scale * max_rand + low).ge_mask(high);
        loop {
            let mask = (scale * max_rand + low).ge_mask(high);
            if !mask.any() {
                break;
            }
            // Step the offending lanes down one ULP; others are untouched.
            scale = scale.decrease_masked(mask);
        }
        // We have ensured that `scale * max_rand + low < high`. Now see if
        // we can increase `scale` and still maintain that inequality. We
        // only need to do this if `scale` was not initially too big.
        let not_too_big_mask = !too_big_mask;
        let mut mask = not_too_big_mask;
        if mask.any() {
            loop {
                // Candidate: bump the still-active lanes up one ULP.
                let next_scale = scale.increase_masked(mask);
                // Keep only the lanes where the bumped scale still satisfies
                // the strict inequality (and were not initially too big).
                mask = (next_scale * max_rand + low).lt_mask(high) & not_too_big_mask;
                if !mask.any() {
                    break;
                }
                scale = scale.increase_masked(mask);
            }
        }
        scale
    }
}

impl UniformSampler for UniformFloat<$ty> {
type X = $ty;

Expand All @@ -849,22 +946,12 @@ macro_rules! uniform_float_impl {
if !(low.all_lt(high)) {
return Err(Error::EmptyRange);
}
let max_rand = <$ty>::splat(
(::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
);

let mut scale = high - low;
if !(scale.all_finite()) {
if !((high - low).all_finite()) {
return Err(Error::NonFinite);
}

loop {
let mask = (scale * max_rand + low).ge_mask(high);
if !mask.any() {
break;
}
scale = scale.decrease_masked(mask);
}
let scale = <$ty>::compute_scale(low, high);

debug_assert!(<$ty>::splat(0.0).all_le(scale));

Expand Down Expand Up @@ -1430,6 +1517,110 @@ mod tests {
}
}

// Generates a scalar test function per `name: float_type` pair, checking
// that `compute_scale` returns the largest scale satisfying the sampler's
// strict upper-bound invariant.
macro_rules! compute_scale_scalar_tests {
    ($($fname:ident: $ty:ident,)*) => {
        $(
            #[test]
            fn $fname() {
                // For each `(low, high)` pair in `v`, compute `scale` and
                // verify that
                //     scale * max_rand + low < high
                // and
                //     next_up(scale) * max_rand + low >= high
                // i.e. `scale` is maximal: one ULP more would break the bound.
                let eps = $ty::EPSILON;
                let v = [
                    (0.0 as $ty, 100.0 as $ty),
                    (-0.125 as $ty, 0.0 as $ty),
                    (0.0 as $ty, 0.125 as $ty),
                    (-1.5 as $ty, -0.0 as $ty),
                    (-0.0 as $ty, 1.5 as $ty),
                    (-1.0 as $ty, -0.875 as $ty),
                    (-1e35 as $ty, -1e25 as $ty),
                    (1e-35 as $ty, 1e-25 as $ty),
                    (-1e35 as $ty, 1e35 as $ty),
                    // This one is mainly for f64, because these particular
                    // numbers provided an example where the initial calculation
                    // of `scale` is too small:
                    (11.0 as $ty, 1001.000000005 as $ty),
                    // Very small intervals--the difference `high - low` is
                    // a small to moderate multiple of the type's EPSILON.
                    (1.0 as $ty - (11.5 as $ty) * eps, 1.0 as $ty - (0.5 as $ty) * eps),
                    (1.0 as $ty - (196389.0 as $ty) * eps / (2.0 as $ty), 1.0 as $ty),
                    (1.0 as $ty, 1.0 as $ty + (1.0 as $ty) * eps),
                    (1.0 as $ty, 1.0 as $ty + (2.0 as $ty) * eps),
                    (1.0 as $ty - eps, 1.0 as $ty),
                    (1.0 as $ty - eps, 1.0 as $ty + (2.0 as $ty) * eps),
                    (-1.0 as $ty, -1.0 as $ty + (2.0 as $ty) * eps),
                    (-2.0 as $ty, -2.0 as $ty + (17.0 as $ty) * eps),
                    (-11.0 as $ty, -11.0 as $ty + (68.0 as $ty) * eps),
                    // Ridiculously small intervals: `low` and `high` are subnormal.
                    (-$ty::from_bits(3), $ty::from_bits(8)),
                    (-$ty::from_bits(5), -$ty::from_bits(1)),
                    // `high - low` is a significant fraction of the type's MAX.
                    ((0.5 as $ty) * $ty::MIN, (0.25 as $ty) * $ty::MAX),
                    ((0.25 as $ty) * $ty::MIN, (0.5 as $ty) * $ty::MAX),
                    ((0.5 as $ty) * $ty::MIN, (0.4999995 as $ty) * $ty::MAX),
                    ((0.75 as $ty) * $ty::MIN, 0.0 as $ty),
                    (0.0 as $ty, (0.75 as $ty) * $ty::MAX),
                ];
                let max_rand = 1.0 as $ty - eps;
                for (low, high) in v {
                    let scale = <$ty>::compute_scale(low, high);
                    assert!(scale > 0.0 as $ty);
                    assert!(scale * max_rand + low < high);
                    // `next_up(scale)` via bit manipulation: valid because
                    // `scale` is positive and finite here.
                    let next_scale = <$ty>::from_bits(scale.to_bits() + 1);
                    assert!(next_scale * max_rand + low >= high);
                }
            }
        )*
    }
}

compute_scale_scalar_tests! {
    test_compute_scale_scalar_f32: f32,
    test_compute_scale_scalar_f64: f64,
}

// Generates a SIMD test function per `name: (vector_type, scalar_type)`
// pair, checking that the vector `compute_scale` agrees lane-for-lane
// with the scalar implementation.
macro_rules! compute_scale_simd_tests {
    ($($fname:ident: ($ty:ty, $f_scalar:ident),)*) => {
        $(
            #[test]
            #[cfg(feature = "simd_support")]
            fn $fname() {
                let low_vals = [0.0 as $f_scalar, -1e35 as $f_scalar, 1e-7 as $f_scalar, -13.25 as $f_scalar];
                let high_vals = [1.5 as $f_scalar, 3e35 as $f_scalar, 1.125 as $f_scalar, -0.0 as $f_scalar];
                // Test that the vector version gives the same results as
                // the scalar version. Create test vectors that use two
                // values at a time from arrays `low_vals` and `high_vals`.
                for k in 0..(low_vals.len() - 1) {
                    // `interleave` alternates the lanes of `c1` and `c2`, so
                    // each test vector mixes two distinct (low, high) pairs.
                    let c1 = <$ty>::splat(low_vals[k]);
                    let c2 = <$ty>::splat(low_vals[k + 1]);
                    let (low, _) = c1.interleave(c2);
                    let c1 = <$ty>::splat(high_vals[k]);
                    let c2 = <$ty>::splat(high_vals[k + 1]);
                    let (high, _) = c1.interleave(c2);
                    let scale = <$ty>::compute_scale(low, high);
                    // Each lane must match the scalar computation exactly.
                    for i in 0..<$ty>::LANES {
                        assert_eq!(scale.extract(i), <$f_scalar>::compute_scale(low.extract(i), high.extract(i)));
                    }
                }
            }
        )*
    }
}

compute_scale_simd_tests! {
    test_compute_scale_f32x2: (f32x2, f32),
    test_compute_scale_f32x4: (f32x4, f32),
    test_compute_scale_f32x8: (f32x8, f32),
    test_compute_scale_f32x16: (f32x16, f32),
    test_compute_scale_f64x2: (f64x2, f64),
    test_compute_scale_f64x4: (f64x4, f64),
    test_compute_scale_f64x8: (f64x8, f64),
}


#[test]
fn test_float_overflow() {
assert_eq!(Uniform::try_from(::core::f64::MIN..::core::f64::MAX), Err(Error::NonFinite));
Expand Down