Skip to content

Commit

Permalink
Add const-generic par_array_windows
Browse files Browse the repository at this point in the history
  • Loading branch information
cuviper committed Aug 22, 2022
1 parent 9bbb18b commit 831a9fb
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 3 deletions.
66 changes: 65 additions & 1 deletion src/slice/array.rs
Expand Up @@ -3,7 +3,7 @@
use crate::iter::plumbing::*;
use crate::iter::*;

use super::{Iter, IterMut};
use super::{Iter, IterMut, ParallelSlice};

/// Parallel iterator over immutable non-overlapping chunks of a slice
#[derive(Debug)]
Expand Down Expand Up @@ -172,3 +172,67 @@ impl<'data, T: Send + 'data, const N: usize> IndexedParallelIterator
self.iter.with_producer(callback)
}
}

/// Parallel iterator over immutable overlapping windows of a slice
#[derive(Debug)]
pub struct ArrayWindows<'data, T: Sync, const N: usize> {
slice: &'data [T],
}

impl<'data, T: Sync, const N: usize> ArrayWindows<'data, T, N> {
pub(super) fn new(slice: &'data [T]) -> Self {
ArrayWindows { slice }
}
}

impl<'data, T: Sync, const N: usize> Clone for ArrayWindows<'data, T, N> {
fn clone(&self) -> Self {
ArrayWindows { ..*self }
}
}

impl<'data, T: Sync + 'data, const N: usize> ParallelIterator for ArrayWindows<'data, T, N> {
type Item = &'data [T; N];

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where
C: UnindexedConsumer<Self::Item>,
{
bridge(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
Some(self.len())
}
}

impl<'data, T: Sync + 'data, const N: usize> IndexedParallelIterator for ArrayWindows<'data, T, N> {
fn drive<C>(self, consumer: C) -> C::Result
where
C: Consumer<Self::Item>,
{
bridge(self, consumer)
}

fn len(&self) -> usize {
assert!(N >= 1);
self.slice.len().saturating_sub(N - 1)
}

fn with_producer<CB>(self, callback: CB) -> CB::Output
where
CB: ProducerCallback<Self::Item>,
{
fn array<T, const N: usize>(slice: &[T]) -> &[T; N] {
debug_assert_eq!(slice.len(), N);
let ptr = slice.as_ptr() as *const [T; N];
unsafe { &*ptr }
}

// FIXME: use our own producer and the standard `array_windows`, rust-lang/rust#75027
self.slice
.par_windows(N)
.map(array::<T, N>)
.with_producer(callback)
}
}
19 changes: 17 additions & 2 deletions src/slice/mod.rs
Expand Up @@ -13,8 +13,8 @@ mod rchunks;

mod test;

#[cfg(min_const_generics)]
pub use self::array::{ArrayChunks, ArrayChunksMut};
#[cfg(has_min_const_generics)]
pub use self::array::{ArrayChunks, ArrayChunksMut, ArrayWindows};

use self::mergesort::par_mergesort;
use self::quicksort::par_quicksort;
Expand Down Expand Up @@ -75,6 +75,21 @@ pub trait ParallelSlice<T: Sync> {
}
}

/// Returns a parallel iterator over all contiguous array windows of
/// length `N`. The windows overlap.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let windows: Vec<_> = [1, 2, 3].par_array_windows().collect();
/// assert_eq!(vec![&[1, 2], &[2, 3]], windows);
/// ```
#[cfg(has_min_const_generics)]
fn par_array_windows<const N: usize>(&self) -> ArrayWindows<'_, T, N> {
ArrayWindows::new(self.as_parallel_slice())
}

/// Returns a parallel iterator over at most `chunk_size` elements of
/// `self` at a time. The chunks do not overlap.
///
Expand Down
1 change: 1 addition & 0 deletions tests/clones.rs
Expand Up @@ -106,6 +106,7 @@ fn clone_vec() {
check(v.par_rchunks_exact(42));
check(v.par_array_chunks::<42>());
check(v.par_windows(42));
check(v.par_array_windows::<42>());
check(v.par_split(|x| x % 3 == 0));
check(v.into_par_iter());
}
Expand Down
1 change: 1 addition & 0 deletions tests/debug.rs
Expand Up @@ -130,6 +130,7 @@ fn debug_vec() {
check(v.par_rchunks_mut(42));
check(v.par_rchunks_exact_mut(42));
check(v.par_windows(42));
check(v.par_array_windows::<42>());
check(v.par_split(|x| x % 3 == 0));
check(v.par_split_mut(|x| x % 3 == 0));
check(v.par_drain(..));
Expand Down
9 changes: 9 additions & 0 deletions tests/producer_split_at.rs
Expand Up @@ -328,6 +328,15 @@ fn slice_windows() {
check(&v, || s.par_windows(2));
}

#[test]
fn slice_array_windows() {
use std::convert::TryInto;
let s: Vec<_> = (0..10).collect();
// FIXME: use the standard `array_windows`, rust-lang/rust#75027
let v: Vec<&[_; 2]> = s.windows(2).map(|s| s.try_into().unwrap()).collect();
check(&v, || s.par_array_windows::<2>());
}

#[test]
fn vec() {
let v: Vec<_> = (0..10).collect();
Expand Down

0 comments on commit 831a9fb

Please sign in to comment.