Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TryFlattenUnordered #2577

Merged
merged 11 commits into from
Apr 16, 2022
12 changes: 7 additions & 5 deletions futures-util/src/future/try_select.rs
Expand Up @@ -12,6 +12,9 @@ pub struct TrySelect<A, B> {

impl<A: Unpin, B: Unpin> Unpin for TrySelect<A, B> {}

type EitherOk<A, B> = Either<(<A as TryFuture>::Ok, B), (<B as TryFuture>::Ok, A)>;
type EitherErr<A, B> = Either<(<A as TryFuture>::Error, B), (<B as TryFuture>::Error, A)>;

/// Waits for either one of two differently-typed futures to complete.
///
/// This function will return a new future which awaits for either one of both
Expand Down Expand Up @@ -52,18 +55,17 @@ where
A: TryFuture + Unpin,
B: TryFuture + Unpin,
{
super::assert_future::<
Result<Either<(A::Ok, B), (B::Ok, A)>, Either<(A::Error, B), (B::Error, A)>>,
_,
>(TrySelect { inner: Some((future1, future2)) })
super::assert_future::<Result<EitherOk<A, B>, EitherErr<A, B>>, _>(TrySelect {
inner: Some((future1, future2)),
})
}

impl<A: Unpin, B: Unpin> Future for TrySelect<A, B>
where
A: TryFuture,
B: TryFuture,
{
type Output = Result<Either<(A::Ok, B), (B::Ok, A)>, Either<(A::Error, B), (B::Error, A)>>;
type Output = Result<EitherOk<A, B>, EitherErr<A, B>>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let (mut a, mut b) = self.inner.take().expect("cannot poll Select twice");
Expand Down
7 changes: 5 additions & 2 deletions futures-util/src/stream/mod.rs
Expand Up @@ -39,7 +39,10 @@ pub use self::stream::Forward;

#[cfg(not(futures_no_atomic_cas))]
#[cfg(feature = "alloc")]
pub use self::stream::{BufferUnordered, Buffered, ForEachConcurrent, TryForEachConcurrent};
pub use self::stream::{
BufferUnordered, Buffered, FlatMapUnordered, FlattenUnordered, ForEachConcurrent,
TryForEachConcurrent,
};

#[cfg(not(futures_no_atomic_cas))]
#[cfg(feature = "sink")]
Expand All @@ -61,7 +64,7 @@ pub use self::try_stream::IntoAsyncRead;

#[cfg(not(futures_no_atomic_cas))]
#[cfg(feature = "alloc")]
pub use self::try_stream::{TryBufferUnordered, TryBuffered};
pub use self::try_stream::{TryBufferUnordered, TryBuffered, TryFlattenUnordered};

#[cfg(feature = "sink")]
#[cfg_attr(docsrs, doc(cfg(feature = "sink")))]
Expand Down
131 changes: 57 additions & 74 deletions futures-util/src/stream/stream/flatten_unordered.rs
@@ -1,4 +1,4 @@
use alloc::sync::Arc;
use alloc::{boxed::Box, sync::Arc};
use core::{
cell::UnsafeCell,
convert::identity,
Expand All @@ -22,8 +22,7 @@ use futures_task::{waker, ArcWake};

use crate::stream::FuturesUnordered;

/// There is nothing to poll and stream isn't being
/// polled or waking at the moment.
/// There is nothing to poll and stream isn't being polled/waking/woken at the moment.
const NONE: u8 = 0;

/// Inner streams need to be polled.
Expand All @@ -32,26 +31,19 @@ const NEED_TO_POLL_INNER_STREAMS: u8 = 1;
/// The base stream needs to be polled.
const NEED_TO_POLL_STREAM: u8 = 0b10;

/// It needs to poll base stream and inner streams.
/// Both base stream and inner streams need to be polled.
const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_INNER_STREAMS | NEED_TO_POLL_STREAM;

/// The current stream is being polled at the moment.
const POLLING: u8 = 0b100;

/// Inner streams are being woken at the moment.
const WAKING_INNER_STREAMS: u8 = 0b1000;

/// The base stream is being woken at the moment.
const WAKING_STREAM: u8 = 0b10000;

/// The base stream and inner streams are being woken at the moment.
const WAKING_ALL: u8 = WAKING_STREAM | WAKING_INNER_STREAMS;
/// Stream is being woken at the moment.
const WAKING: u8 = 0b1000;

/// The stream was waked and will be polled.
const WOKEN: u8 = 0b100000;
const WOKEN: u8 = 0b10000;

/// Determines what needs to be polled, and is stream being polled at the
/// moment or not.
/// Internal polling state of the stream.
#[derive(Clone, Debug)]
struct SharedPollState {
state: Arc<AtomicU8>,
Expand All @@ -71,7 +63,7 @@ impl SharedPollState {
let value = self
.state
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
if value & WAKING_ALL == NONE {
if value & WAKING == NONE {
Some(POLLING)
} else {
None
Expand All @@ -83,23 +75,20 @@ impl SharedPollState {
Some((value, bomb))
}

/// Starts the waking process and performs bitwise or with the given value.
/// Attempts to start the waking process and performs bitwise or with the given value.
///
/// If some waker is already in progress or stream is already woken/being polled, waking process won't start, however
/// state will be disjuncted with the given value.
fn start_waking(
&self,
to_poll: u8,
waking: u8,
) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
let value = self
.state
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
// Waking process for this waker already started
if value & waking != NONE {
return None;
}
let mut next_value = value | to_poll;
// Only start the waking process if we're not in the polling phase and the stream isn't woken already
if value & (WOKEN | POLLING) == NONE {
next_value |= waking;
next_value |= WAKING;
}

if next_value != value {
Expand All @@ -110,8 +99,9 @@ impl SharedPollState {
})
.ok()?;

if value & (WOKEN | POLLING) == NONE {
let bomb = PollStateBomb::new(self, move |state| state.stop_waking(waking));
// Only start the waking process if we're not in the polling phase and the stream isn't woken already
if value & (WOKEN | POLLING | WAKING) == NONE {
let bomb = PollStateBomb::new(self, SharedPollState::stop_waking);

Some((value, bomb))
} else {
Expand All @@ -123,7 +113,7 @@ impl SharedPollState {
/// - `!POLLING` allowing to use wakers
/// - `WOKEN` if the state was changed during `POLLING` phase as waker will be called,
/// or `will_be_woken` flag supplied
/// - `!WAKING_ALL` as
/// - `!WAKING` as
/// * Wakers called during the `POLLING` phase won't propagate their calls
/// * `POLLING` phase can't start if some of the wakers are active
/// So no wrapped waker can touch the inner waker's cell, it's safe to poll again.
Expand All @@ -138,20 +128,16 @@ impl SharedPollState {
}
next_value |= value;

Some(next_value & !POLLING & !WAKING_ALL)
Some(next_value & !POLLING & !WAKING)
})
.unwrap()
}

/// Toggles state to non-waking, allowing to start polling.
fn stop_waking(&self, waking: u8) -> u8 {
fn stop_waking(&self) -> u8 {
self.state
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
let mut next_value = value & !waking;
// Waker will be called only if the current waking state is the same as the specified waker state
if value & WAKING_ALL == waking {
next_value |= WOKEN;
}
let next_value = value & !WAKING;

if next_value != value {
Some(next_value)
Expand Down Expand Up @@ -201,24 +187,24 @@ impl<F: FnOnce(&SharedPollState) -> u8> Drop for PollStateBomb<'_, F> {

/// Will update state with the provided value on `wake_by_ref` call
/// and then, if there is a need, call `inner_waker`.
struct InnerWaker {
struct WrappedWaker {
inner_waker: UnsafeCell<Option<Waker>>,
poll_state: SharedPollState,
need_to_poll: u8,
}

unsafe impl Send for InnerWaker {}
unsafe impl Sync for InnerWaker {}
unsafe impl Send for WrappedWaker {}
unsafe impl Sync for WrappedWaker {}

impl InnerWaker {
impl WrappedWaker {
/// Replaces given waker's inner_waker for polling stream/futures which will
/// update poll state on `wake_by_ref` call. Use only if you need several
/// contexts.
///
/// ## Safety
///
/// This function will modify waker's `inner_waker` via `UnsafeCell`, so
/// it should be used only during `POLLING` phase.
/// it should be used only during `POLLING` phase by one thread at the time.
unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) -> Waker {
*self_arc.inner_waker.get() = cx.waker().clone().into();
waker(self_arc.clone())
Expand All @@ -227,16 +213,11 @@ impl InnerWaker {
/// Attempts to start the waking process for the waker with the given value.
/// If succeeded, then the stream isn't yet woken and not being polled at the moment.
fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
self.poll_state.start_waking(self.need_to_poll, self.waking_state())
}

/// Returns the corresponding waking state toggled by this waker.
fn waking_state(&self) -> u8 {
self.need_to_poll << 3
self.poll_state.start_waking(self.need_to_poll)
}
}

impl ArcWake for InnerWaker {
impl ArcWake for WrappedWaker {
fn wake_by_ref(self_arc: &Arc<Self>) {
if let Some((_, state_bomb)) = self_arc.start_waking() {
// Safety: now state is not `POLLING`
Expand All @@ -246,12 +227,8 @@ impl ArcWake for InnerWaker {
// Stop waking to allow polling stream
let poll_state_value = state_bomb.fire().unwrap();

// Here we want to call waker only if stream isn't woken yet and
// also to optimize the case when two wakers are called at the same time.
//
// In this case the best strategy will be to propagate only the latest waker's awake,
// and then poll both entities in a single `poll_next` call
if poll_state_value & (WOKEN | WAKING_ALL) == self_arc.waking_state() {
// We want to call waker only if the stream isn't woken yet
if poll_state_value & (WOKEN | WAKING) == WAKING {
// Wake up inner waker
inner_waker.wake();
}
Expand Down Expand Up @@ -308,14 +285,14 @@ pin_project! {
#[must_use = "streams do nothing unless polled"]
pub struct FlattenUnordered<St> where St: Stream {
#[pin]
inner_streams: FuturesUnordered<PollStreamFut<St::Item>>,
inner_streams: FuturesUnordered<PollStreamFut<Pin<Box<St::Item>>>>,
#[pin]
stream: St,
poll_state: SharedPollState,
limit: Option<NonZeroUsize>,
is_stream_done: bool,
inner_streams_waker: Arc<InnerWaker>,
stream_waker: Arc<InnerWaker>,
inner_streams_waker: Arc<WrappedWaker>,
stream_waker: Arc<WrappedWaker>,
}
}

Expand All @@ -338,7 +315,7 @@ where
impl<St> FlattenUnordered<St>
where
St: Stream,
St::Item: Stream + Unpin,
St::Item: Stream,
{
pub(super) fn new(stream: St, limit: Option<usize>) -> FlattenUnordered<St> {
let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
Expand All @@ -348,12 +325,12 @@ where
stream,
is_stream_done: false,
limit: limit.and_then(NonZeroUsize::new),
inner_streams_waker: Arc::new(InnerWaker {
inner_streams_waker: Arc::new(WrappedWaker {
inner_waker: UnsafeCell::new(None),
poll_state: poll_state.clone(),
need_to_poll: NEED_TO_POLL_INNER_STREAMS,
}),
stream_waker: Arc::new(InnerWaker {
stream_waker: Arc::new(WrappedWaker {
inner_waker: UnsafeCell::new(None),
poll_state: poll_state.clone(),
need_to_poll: NEED_TO_POLL_STREAM,
Expand All @@ -369,7 +346,7 @@ impl<St> FlattenUnorderedProj<'_, St>
where
St: Stream,
{
/// Checks if current `inner_streams` size is less than optional limit.
/// Checks if current `inner_streams` size is greater than optional limit.
fn is_exceeded_limit(&self) -> bool {
self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get())
}
Expand All @@ -378,7 +355,7 @@ where
impl<St> FusedStream for FlattenUnordered<St>
where
St: FusedStream,
St::Item: FusedStream + Unpin,
St::Item: Stream,
{
fn is_terminated(&self) -> bool {
self.stream.is_terminated() && self.inner_streams.is_empty()
Expand All @@ -388,7 +365,7 @@ where
impl<St> Stream for FlattenUnordered<St>
where
St: Stream,
St::Item: Stream + Unpin,
St::Item: Stream,
{
type Item = <St::Item as Stream>::Item;

Expand All @@ -407,8 +384,7 @@ where
};

if poll_state_value & NEED_TO_POLL_STREAM != NONE {
// Safety: now state is `POLLING`.
let stream_waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) };
let mut stream_waker = None;

// Here we need to poll the base stream.
//
Expand All @@ -424,15 +400,24 @@ where

break;
} else {
match this.stream.as_mut().poll_next(&mut Context::from_waker(&stream_waker)) {
// Initialize base stream waker if it's not yet initialized
if stream_waker.is_none() {
// Safety: now state is `POLLING`.
stream_waker
.replace(unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) });
}
let mut cx = Context::from_waker(stream_waker.as_ref().unwrap());

match this.stream.as_mut().poll_next(&mut cx) {
Poll::Ready(Some(inner_stream)) => {
let next_item_fut = PollStreamFut::new(Box::pin(inner_stream));
taiki-e marked this conversation as resolved.
Show resolved Hide resolved
// Add new stream to the inner streams bucket
this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream));
this.inner_streams.as_mut().push(next_item_fut);
// Inner streams must be polled afterward
poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
}
Poll::Ready(None) => {
// Mark the stream as done
// Mark the base stream as done
*this.is_stream_done = true;
}
Poll::Pending => {
Expand All @@ -446,13 +431,10 @@ where
if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
// Safety: now state is `POLLING`.
let inner_streams_waker =
unsafe { InnerWaker::replace_waker(this.inner_streams_waker, cx) };
unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) };
let mut cx = Context::from_waker(&inner_streams_waker);

match this
.inner_streams
.as_mut()
.poll_next(&mut Context::from_waker(&inner_streams_waker))
{
match this.inner_streams.as_mut().poll_next(&mut cx) {
Poll::Ready(Some(Some((item, next_item_fut)))) => {
// Push next inner stream item future to the list of inner streams futures
this.inner_streams.as_mut().push(next_item_fut);
Expand All @@ -472,15 +454,16 @@ where
// We didn't have any `poll_next` panic, so it's time to deactivate the bomb
state_bomb.deactivate();

// Call the waker at the end of polling if
let mut force_wake =
// we need to poll the stream and didn't reach the limit yet
need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit()
// or we need to poll inner streams again
// or we need to poll the inner streams again
|| need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE;

// Stop polling and swap the latest state
poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake);
// If state was changed during `POLLING` phase, need to manually call a waker
// If state was changed during `POLLING` phase, we also need to manually call a waker
force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE;

let is_done = *this.is_stream_done && this.inner_streams.is_empty();
Expand Down