Skip to content

Commit

Permalink
Add methods weight, weights, and total_weight to weighted_index.rs (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelOwenDyer committed Apr 2, 2024
1 parent 0518975 commit 7b37c15
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -9,6 +9,7 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md).
You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful.

## [Unreleased]
- Add `rand::distributions::WeightedIndex::{weight, weights, total_weight}` (#1420)
- Bump the MSRV to 1.61.0

## [0.9.0-alpha.1] - 2024-03-18
Expand Down
188 changes: 188 additions & 0 deletions src/distributions/weighted_index.rs
Expand Up @@ -15,6 +15,7 @@ use core::fmt;

// Note that this whole module is only imported if feature="alloc" is enabled.
use alloc::vec::Vec;
use core::fmt::Debug;

#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -243,6 +244,124 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
}
}

/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution.
/// This is returned by [`WeightedIndex::weights`].
pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
weighted_index: &'a WeightedIndex<X>,
index: usize,
}

impl<'a, X> Debug for WeightedIndexIter<'a, X>
where
X: SampleUniform + PartialOrd + Debug,
X::Sampler: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WeightedIndexIter")
.field("weighted_index", &self.weighted_index)
.field("index", &self.index)
.finish()
}
}

impl<'a, X> Clone for WeightedIndexIter<'a, X>
where
X: SampleUniform + PartialOrd,
{
fn clone(&self) -> Self {
WeightedIndexIter {
weighted_index: self.weighted_index,
index: self.index,
}
}
}

impl<'a, X> Iterator for WeightedIndexIter<'a, X>
where
X: for<'b> ::core::ops::SubAssign<&'b X>
+ SampleUniform
+ PartialOrd
+ Clone,
{
type Item = X;

fn next(&mut self) -> Option<Self::Item> {
match self.weighted_index.weight(self.index) {
None => None,
Some(weight) => {
self.index += 1;
Some(weight)
}
}
}
}

impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
/// Returns the weight at the given index, if it exists.
///
/// If the index is out of bounds, this will return `None`.
///
/// # Example
///
/// ```
/// use rand::distributions::WeightedIndex;
///
/// let weights = [0, 1, 2];
/// let dist = WeightedIndex::new(&weights).unwrap();
/// assert_eq!(dist.weight(0), Some(0));
/// assert_eq!(dist.weight(1), Some(1));
/// assert_eq!(dist.weight(2), Some(2));
/// assert_eq!(dist.weight(3), None);
/// ```
pub fn weight(&self, index: usize) -> Option<X>
where
X: for<'a> ::core::ops::SubAssign<&'a X>
{
let mut weight = if index < self.cumulative_weights.len() {
self.cumulative_weights[index].clone()
} else if index == self.cumulative_weights.len() {
self.total_weight.clone()
} else {
return None;
};
if index > 0 {
weight -= &self.cumulative_weights[index - 1];
}
Some(weight)
}

/// Returns a lazy-loading iterator containing the current weights of this distribution.
///
/// If this distribution has not been updated since its creation, this will return the
/// same weights as were passed to `new`.
///
/// # Example
///
/// ```
/// use rand::distributions::WeightedIndex;
///
/// let weights = [1, 2, 3];
/// let mut dist = WeightedIndex::new(&weights).unwrap();
/// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![1, 2, 3]);
/// dist.update_weights(&[(0, &2)]).unwrap();
/// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![2, 2, 3]);
/// ```
pub fn weights(&self) -> WeightedIndexIter<'_, X>
where
X: for<'a> ::core::ops::SubAssign<&'a X>
{
WeightedIndexIter {
weighted_index: self,
index: 0,
}
}

/// Returns the sum of all weights in this distribution.
pub fn total_weight(&self) -> X {
self.total_weight.clone()
}
}

impl<X> Distribution<usize> for WeightedIndex<X>
where
X: SampleUniform + PartialOrd,
Expand Down Expand Up @@ -458,6 +577,75 @@ mod test {
}
}

#[test]
fn test_update_weights_errors() {
let data = [
(
&[1i32, 0, 0][..],
&[(0, &0)][..],
WeightError::InsufficientNonZero,
),
(
&[10, 10, 10, 10][..],
&[(1, &-11)][..],
WeightError::InvalidWeight, // A weight is negative
),
(
&[1, 2, 3, 4, 5][..],
&[(1, &5), (0, &5)][..], // Wrong order
WeightError::InvalidInput,
),
(
&[1][..],
&[(1, &1)][..], // Index too large
WeightError::InvalidInput,
),
];

for (weights, update, err) in data.iter() {
let total_weight = weights.iter().sum::<i32>();
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
assert_eq!(distr.total_weight, total_weight);
match distr.update_weights(update) {
Ok(_) => panic!("Expected update_weights to fail, but it succeeded"),
Err(e) => assert_eq!(e, *err),
}
}
}

#[test]
fn test_weight_at() {
let data = [
&[1][..],
&[10, 2, 3, 4][..],
&[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
&[u32::MAX][..],
];

for weights in data.iter() {
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
for (i, weight) in weights.iter().enumerate() {
assert_eq!(distr.weight(i), Some(*weight));
}
assert_eq!(distr.weight(weights.len()), None);
}
}

#[test]
fn test_weights() {
let data = [
&[1][..],
&[10, 2, 3, 4][..],
&[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
&[u32::MAX][..],
];

for weights in data.iter() {
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
assert_eq!(distr.weights().collect::<Vec<_>>(), weights.to_vec());
}
}

#[test]
fn value_stability() {
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
Expand Down

0 comments on commit 7b37c15

Please sign in to comment.