From 77ba20bd22584e008d9e6163f8bc0addee0c9714 Mon Sep 17 00:00:00 2001 From: Oleg Nosov Date: Sat, 16 Apr 2022 20:03:47 +0400 Subject: [PATCH] Add `TryFlattenUnordered` and improve `FlattenUnordered` (#2577) --- futures-util/benches/flatten_unordered.rs | 7 +- futures-util/src/future/try_select.rs | 13 +- futures-util/src/stream/mod.rs | 15 +- .../src/stream/stream/flatten_unordered.rs | 123 +++++++--------- futures-util/src/stream/stream/mod.rs | 15 +- futures-util/src/stream/try_stream/mod.rs | 65 +++++++++ .../src/stream/try_stream/try_chunks.rs | 5 +- .../try_stream/try_flatten_unordered.rs | 133 ++++++++++++++++++ futures/tests/stream_try_stream.rs | 41 ++++++ 9 files changed, 326 insertions(+), 91 deletions(-) create mode 100644 futures-util/src/stream/try_stream/try_flatten_unordered.rs diff --git a/futures-util/benches/flatten_unordered.rs b/futures-util/benches/flatten_unordered.rs index 64d5f9a4e3..b92f614914 100644 --- a/futures-util/benches/flatten_unordered.rs +++ b/futures-util/benches/flatten_unordered.rs @@ -5,7 +5,7 @@ use crate::test::Bencher; use futures::channel::oneshot; use futures::executor::block_on; -use futures::future::{self, FutureExt}; +use futures::future; use futures::stream::{self, StreamExt}; use futures::task::Poll; use std::collections::VecDeque; @@ -35,15 +35,14 @@ fn oneshot_streams(b: &mut Bencher) { }); let mut flatten = stream::unfold(rxs.into_iter(), |mut vals| { - async { + Box::pin(async { if let Some(next) = vals.next() { let val = next.await.unwrap(); Some((val, vals)) } else { None } - } - .boxed() + }) }) .flatten_unordered(None); diff --git a/futures-util/src/future/try_select.rs b/futures-util/src/future/try_select.rs index 4d0b7ff135..bc282f7db1 100644 --- a/futures-util/src/future/try_select.rs +++ b/futures-util/src/future/try_select.rs @@ -12,6 +12,9 @@ pub struct TrySelect { impl Unpin for TrySelect {} +type EitherOk = Either<(::Ok, B), (::Ok, A)>; +type EitherErr = Either<(::Error, B), (::Error, A)>; + /// Waits for either one of two differently-typed futures to complete. /// /// This function will return a new future which awaits for either one of both @@ -52,10 +55,9 @@ where A: TryFuture + Unpin, B: TryFuture + Unpin, { - super::assert_future::< - Result, Either<(A::Error, B), (B::Error, A)>>, - _, - >(TrySelect { inner: Some((future1, future2)) }) + super::assert_future::, EitherErr>, _>(TrySelect { + inner: Some((future1, future2)), + }) } impl Future for TrySelect @@ -63,8 +65,7 @@ where A: TryFuture, B: TryFuture, { - #[allow(clippy::type_complexity)] - type Output = Result, Either<(A::Error, B), (B::Error, A)>>; + type Output = Result, EitherErr>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let (mut a, mut b) = self.inner.take().expect("cannot poll Select twice"); diff --git a/futures-util/src/stream/mod.rs b/futures-util/src/stream/mod.rs index ec685b9848..bf9506147c 100644 --- a/futures-util/src/stream/mod.rs +++ b/futures-util/src/stream/mod.rs @@ -18,9 +18,10 @@ pub use futures_core::stream::{FusedStream, Stream, TryStream}; #[allow(clippy::module_inception)] mod stream; pub use self::stream::{ - Chain, Collect, Concat, Cycle, Enumerate, Filter, FilterMap, FlatMap, Flatten, Fold, ForEach, - Fuse, Inspect, Map, Next, NextIf, NextIfEq, Peek, PeekMut, Peekable, Scan, SelectNextSome, - Skip, SkipWhile, StreamExt, StreamFuture, Take, TakeUntil, TakeWhile, Then, Unzip, Zip, + All, Any, Chain, Collect, Concat, Count, Cycle, Enumerate, Filter, FilterMap, FlatMap, Flatten, + Fold, ForEach, Fuse, Inspect, Map, Next, NextIf, NextIfEq, Peek, PeekMut, Peekable, Scan, + SelectNextSome, Skip, SkipWhile, StreamExt, StreamFuture, Take, TakeUntil, TakeWhile, Then, + Unzip, Zip, }; #[cfg(feature = "std")] @@ -38,7 +39,9 @@ pub use self::stream::Forward; #[cfg(not(futures_no_atomic_cas))] #[cfg(feature = "alloc")] -pub use self::stream::{BufferUnordered, Buffered, ForEachConcurrent}; +pub use self::stream::{ + BufferUnordered, Buffered, FlatMapUnordered, FlattenUnordered, ForEachConcurrent, +}; #[cfg(not(futures_no_atomic_cas))] #[cfg(feature = "sink")] @@ -60,7 +63,9 @@ pub use self::try_stream::IntoAsyncRead; #[cfg(not(futures_no_atomic_cas))] #[cfg(feature = "alloc")] -pub use self::try_stream::{TryBufferUnordered, TryBuffered, TryForEachConcurrent}; +pub use self::try_stream::{ + TryBufferUnordered, TryBuffered, TryFlattenUnordered, TryForEachConcurrent, +}; #[cfg(feature = "alloc")] pub use self::try_stream::{TryChunks, TryChunksError}; diff --git a/futures-util/src/stream/stream/flatten_unordered.rs b/futures-util/src/stream/stream/flatten_unordered.rs index 07f971c55a..66ba4d0d55 100644 --- a/futures-util/src/stream/stream/flatten_unordered.rs +++ b/futures-util/src/stream/stream/flatten_unordered.rs @@ -22,8 +22,7 @@ use futures_task::{waker, ArcWake}; use crate::stream::FuturesUnordered; -/// There is nothing to poll and stream isn't being -/// polled or waking at the moment. +/// There is nothing to poll and stream isn't being polled/waking/woken at the moment. const NONE: u8 = 0; /// Inner streams need to be polled. @@ -32,26 +31,19 @@ const NEED_TO_POLL_INNER_STREAMS: u8 = 1; /// The base stream needs to be polled. const NEED_TO_POLL_STREAM: u8 = 0b10; -/// It needs to poll base stream and inner streams. +/// Both base stream and inner streams need to be polled. const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_INNER_STREAMS | NEED_TO_POLL_STREAM; /// The current stream is being polled at the moment. const POLLING: u8 = 0b100; -/// Inner streams are being woken at the moment. -const WAKING_INNER_STREAMS: u8 = 0b1000; - -/// The base stream is being woken at the moment. -const WAKING_STREAM: u8 = 0b10000; - -/// The base stream and inner streams are being woken at the moment. -const WAKING_ALL: u8 = WAKING_STREAM | WAKING_INNER_STREAMS; +/// Stream is being woken at the moment. +const WAKING: u8 = 0b1000; /// The stream was waked and will be polled. -const WOKEN: u8 = 0b100000; +const WOKEN: u8 = 0b10000; -/// Determines what needs to be polled, and is stream being polled at the -/// moment or not. +/// Internal polling state of the stream. #[derive(Clone, Debug)] struct SharedPollState { state: Arc, @@ -71,7 +63,7 @@ impl SharedPollState { let value = self .state .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { - if value & WAKING_ALL == NONE { + if value & WAKING == NONE { Some(POLLING) } else { None @@ -83,23 +75,20 @@ impl SharedPollState { Some((value, bomb)) } - /// Starts the waking process and performs bitwise or with the given value. + /// Attempts to start the waking process and performs bitwise or with the given value. + /// + /// If some waker is already in progress or stream is already woken/being polled, waking process won't start, however + /// state will be disjuncted with the given value. fn start_waking( &self, to_poll: u8, - waking: u8, ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> { let value = self .state .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { - // Waking process for this waker already started - if value & waking != NONE { - return None; - } let mut next_value = value | to_poll; - // Only start the waking process if we're not in the polling phase and the stream isn't woken already if value & (WOKEN | POLLING) == NONE { - next_value |= waking; + next_value |= WAKING; } if next_value != value { @@ -110,8 +99,9 @@ impl SharedPollState { }) .ok()?; - if value & (WOKEN | POLLING) == NONE { - let bomb = PollStateBomb::new(self, move |state| state.stop_waking(waking)); + // Only start the waking process if we're not in the polling phase and the stream isn't woken already + if value & (WOKEN | POLLING | WAKING) == NONE { + let bomb = PollStateBomb::new(self, SharedPollState::stop_waking); Some((value, bomb)) } else { @@ -123,7 +113,7 @@ impl SharedPollState { /// - `!POLLING` allowing to use wakers /// - `WOKEN` if the state was changed during `POLLING` phase as waker will be called, /// or `will_be_woken` flag supplied - /// - `!WAKING_ALL` as + /// - `!WAKING` as /// * Wakers called during the `POLLING` phase won't propagate their calls /// * `POLLING` phase can't start if some of the wakers are active /// So no wrapped waker can touch the inner waker's cell, it's safe to poll again. @@ -138,20 +128,16 @@ impl SharedPollState { } next_value |= value; - Some(next_value & !POLLING & !WAKING_ALL) + Some(next_value & !POLLING & !WAKING) }) .unwrap() } /// Toggles state to non-waking, allowing to start polling. - fn stop_waking(&self, waking: u8) -> u8 { + fn stop_waking(&self) -> u8 { self.state .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { - let mut next_value = value & !waking; - // Waker will be called only if the current waking state is the same as the specified waker state - if value & WAKING_ALL == waking { - next_value |= WOKEN; - } + let next_value = value & !WAKING; if next_value != value { Some(next_value) @@ -201,16 +187,16 @@ impl u8> Drop for PollStateBomb<'_, F> { /// Will update state with the provided value on `wake_by_ref` call /// and then, if there is a need, call `inner_waker`. -struct InnerWaker { +struct WrappedWaker { inner_waker: UnsafeCell>, poll_state: SharedPollState, need_to_poll: u8, } -unsafe impl Send for InnerWaker {} -unsafe impl Sync for InnerWaker {} +unsafe impl Send for WrappedWaker {} +unsafe impl Sync for WrappedWaker {} -impl InnerWaker { +impl WrappedWaker { /// Replaces given waker's inner_waker for polling stream/futures which will /// update poll state on `wake_by_ref` call. Use only if you need several /// contexts. @@ -218,7 +204,7 @@ impl InnerWaker { /// ## Safety /// /// This function will modify waker's `inner_waker` via `UnsafeCell`, so - /// it should be used only during `POLLING` phase. + /// it should be used only during `POLLING` phase by one thread at the time. unsafe fn replace_waker(self_arc: &mut Arc, cx: &Context<'_>) -> Waker { *self_arc.inner_waker.get() = cx.waker().clone().into(); waker(self_arc.clone()) @@ -227,16 +213,11 @@ impl InnerWaker { /// Attempts to start the waking process for the waker with the given value. /// If succeeded, then the stream isn't yet woken and not being polled at the moment. fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> { - self.poll_state.start_waking(self.need_to_poll, self.waking_state()) - } - - /// Returns the corresponding waking state toggled by this waker. - fn waking_state(&self) -> u8 { - self.need_to_poll << 3 + self.poll_state.start_waking(self.need_to_poll) } } -impl ArcWake for InnerWaker { +impl ArcWake for WrappedWaker { fn wake_by_ref(self_arc: &Arc) { if let Some((_, state_bomb)) = self_arc.start_waking() { // Safety: now state is not `POLLING` @@ -246,12 +227,8 @@ impl ArcWake for InnerWaker { // Stop waking to allow polling stream let poll_state_value = state_bomb.fire().unwrap(); - // Here we want to call waker only if stream isn't woken yet and - // also to optimize the case when two wakers are called at the same time. - // - // In this case the best strategy will be to propagate only the latest waker's awake, - // and then poll both entities in a single `poll_next` call - if poll_state_value & (WOKEN | WAKING_ALL) == self_arc.waking_state() { + // We want to call waker only if the stream isn't woken yet + if poll_state_value & (WOKEN | WAKING) == WAKING { // Wake up inner waker inner_waker.wake(); } @@ -314,8 +291,8 @@ pin_project! { poll_state: SharedPollState, limit: Option, is_stream_done: bool, - inner_streams_waker: Arc, - stream_waker: Arc, + inner_streams_waker: Arc, + stream_waker: Arc, } } @@ -348,12 +325,12 @@ where stream, is_stream_done: false, limit: limit.and_then(NonZeroUsize::new), - inner_streams_waker: Arc::new(InnerWaker { + inner_streams_waker: Arc::new(WrappedWaker { inner_waker: UnsafeCell::new(None), poll_state: poll_state.clone(), need_to_poll: NEED_TO_POLL_INNER_STREAMS, }), - stream_waker: Arc::new(InnerWaker { + stream_waker: Arc::new(WrappedWaker { inner_waker: UnsafeCell::new(None), poll_state: poll_state.clone(), need_to_poll: NEED_TO_POLL_STREAM, @@ -369,7 +346,7 @@ impl FlattenUnorderedProj<'_, St> where St: Stream, { - /// Checks if current `inner_streams` size is less than optional limit. + /// Checks if current `inner_streams` bucket size is greater than optional limit. fn is_exceeded_limit(&self) -> bool { self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get()) } @@ -378,7 +355,7 @@ where impl FusedStream for FlattenUnordered where St: FusedStream, - St::Item: FusedStream + Unpin, + St::Item: Stream + Unpin, { fn is_terminated(&self) -> bool { self.stream.is_terminated() && self.inner_streams.is_empty() @@ -407,8 +384,7 @@ where }; if poll_state_value & NEED_TO_POLL_STREAM != NONE { - // Safety: now state is `POLLING`. - let stream_waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) }; + let mut stream_waker = None; // Here we need to poll the base stream. // @@ -424,15 +400,24 @@ where break; } else { - match this.stream.as_mut().poll_next(&mut Context::from_waker(&stream_waker)) { + // Initialize base stream waker if it's not yet initialized + if stream_waker.is_none() { + // Safety: now state is `POLLING`. + stream_waker + .replace(unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) }); + } + let mut cx = Context::from_waker(stream_waker.as_ref().unwrap()); + + match this.stream.as_mut().poll_next(&mut cx) { Poll::Ready(Some(inner_stream)) => { + let next_item_fut = PollStreamFut::new(inner_stream); // Add new stream to the inner streams bucket - this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream)); + this.inner_streams.as_mut().push(next_item_fut); // Inner streams must be polled afterward poll_state_value |= NEED_TO_POLL_INNER_STREAMS; } Poll::Ready(None) => { - // Mark the stream as done + // Mark the base stream as done *this.is_stream_done = true; } Poll::Pending => { @@ -446,13 +431,10 @@ where if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE { // Safety: now state is `POLLING`. let inner_streams_waker = - unsafe { InnerWaker::replace_waker(this.inner_streams_waker, cx) }; + unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) }; + let mut cx = Context::from_waker(&inner_streams_waker); - match this - .inner_streams - .as_mut() - .poll_next(&mut Context::from_waker(&inner_streams_waker)) - { + match this.inner_streams.as_mut().poll_next(&mut cx) { Poll::Ready(Some(Some((item, next_item_fut)))) => { // Push next inner stream item future to the list of inner streams futures this.inner_streams.as_mut().push(next_item_fut); @@ -472,15 +454,16 @@ where // We didn't have any `poll_next` panic, so it's time to deactivate the bomb state_bomb.deactivate(); + // Call the waker at the end of polling if let mut force_wake = // we need to poll the stream and didn't reach the limit yet need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit() - // or we need to poll inner streams again + // or we need to poll the inner streams again || need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE; // Stop polling and swap the latest state poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake); - // If state was changed during `POLLING` phase, need to manually call a waker + // If state was changed during `POLLING` phase, we also need to manually call a waker force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE; let is_done = *this.is_stream_done && this.inner_streams.is_empty(); diff --git a/futures-util/src/stream/stream/mod.rs b/futures-util/src/stream/stream/mod.rs index bb5e24907d..eb86cb757d 100644 --- a/futures-util/src/stream/stream/mod.rs +++ b/futures-util/src/stream/stream/mod.rs @@ -774,7 +774,14 @@ pub trait StreamExt: Stream { } /// Flattens a stream of streams into just one continuous stream. Polls - /// inner streams concurrently. + /// inner streams produced by the base stream concurrently. + /// + /// The only argument is an optional limit on the number of concurrently + /// polled streams. If this limit is not `None`, no more than `limit` streams + /// will be polled at the same time. The `limit` argument is of type + /// `Into>`, and so can be provided as either `None`, + /// `Some(10)`, or just `10`. Note: a limit of zero is interpreted as + /// no limit at all, and will have the same result as passing in `None`. /// /// # Examples /// @@ -814,7 +821,7 @@ pub trait StreamExt: Stream { Self::Item: Stream + Unpin, Self: Sized, { - FlattenUnordered::new(self, limit.into()) + assert_stream::<::Item, _>(FlattenUnordered::new(self, limit.into())) } /// Maps a stream like [`StreamExt::map`] but flattens nested `Stream`s. @@ -863,7 +870,7 @@ pub trait StreamExt: Stream { /// /// The first argument is an optional limit on the number of concurrently /// polled streams. If this limit is not `None`, no more than `limit` streams - /// will be polled concurrently. The `limit` argument is of type + /// will be polled at the same time. The `limit` argument is of type /// `Into>`, and so can be provided as either `None`, /// `Some(10)`, or just `10`. Note: a limit of zero is interpreted as /// no limit at all, and will have the same result as passing in `None`. @@ -901,7 +908,7 @@ pub trait StreamExt: Stream { F: FnMut(Self::Item) -> U, Self: Sized, { - FlatMapUnordered::new(self, limit.into(), f) + assert_stream::(FlatMapUnordered::new(self, limit.into(), f)) } /// Combinator similar to [`StreamExt::fold`] that holds internal state diff --git a/futures-util/src/stream/try_stream/mod.rs b/futures-util/src/stream/try_stream/mod.rs index bc4c6e4f6a..42f5e7324b 100644 --- a/futures-util/src/stream/try_stream/mod.rs +++ b/futures-util/src/stream/try_stream/mod.rs @@ -15,6 +15,7 @@ use crate::stream::{Inspect, Map}; #[cfg(feature = "alloc")] use alloc::vec::Vec; use core::pin::Pin; + use futures_core::{ future::{Future, TryFuture}, stream::TryStream, @@ -88,6 +89,14 @@ mod try_flatten; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::try_flatten::TryFlatten; +#[cfg(not(futures_no_atomic_cas))] +#[cfg(feature = "alloc")] +mod try_flatten_unordered; +#[cfg(not(futures_no_atomic_cas))] +#[cfg(feature = "alloc")] +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::try_flatten_unordered::TryFlattenUnordered; + mod try_collect; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::try_collect::TryCollect; @@ -711,6 +720,62 @@ pub trait TryStreamExt: TryStream { assert_stream::, _>(TryFilterMap::new(self, f)) } + /// Flattens a stream of streams into just one continuous stream. Produced streams + /// will be polled concurrently and any errors are passed through without looking at them. + /// + /// The only argument is an optional limit on the number of concurrently + /// polled streams. If this limit is not `None`, no more than `limit` streams + /// will be polled at the same time. The `limit` argument is of type + /// `Into>`, and so can be provided as either `None`, + /// `Some(10)`, or just `10`. Note: a limit of zero is interpreted as + /// no limit at all, and will have the same result as passing in `None`. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::channel::mpsc; + /// use futures::stream::{StreamExt, TryStreamExt}; + /// use std::thread; + /// + /// let (tx1, rx1) = mpsc::unbounded(); + /// let (tx2, rx2) = mpsc::unbounded(); + /// let (tx3, rx3) = mpsc::unbounded(); + /// + /// thread::spawn(move || { + /// tx1.unbounded_send(Ok(1)).unwrap(); + /// }); + /// thread::spawn(move || { + /// tx2.unbounded_send(Ok(2)).unwrap(); + /// tx2.unbounded_send(Err(3)).unwrap(); + /// tx2.unbounded_send(Ok(4)).unwrap(); + /// }); + /// thread::spawn(move || { + /// tx3.unbounded_send(Ok(rx1)).unwrap(); + /// tx3.unbounded_send(Ok(rx2)).unwrap(); + /// tx3.unbounded_send(Err(5)).unwrap(); + /// }); + /// + /// let stream = rx3.try_flatten_unordered(None); + /// let mut values: Vec<_> = stream.collect().await; + /// values.sort(); + /// + /// assert_eq!(values, vec![Ok(1), Ok(2), Ok(4), Err(3), Err(5)]); + /// # }); + /// ``` + #[cfg(not(futures_no_atomic_cas))] + #[cfg(feature = "alloc")] + fn try_flatten_unordered(self, limit: impl Into>) -> TryFlattenUnordered + where + Self::Ok: TryStream + Unpin, + ::Error: From, + Self: Sized, + { + assert_stream::::Ok, ::Error>, _>( + TryFlattenUnordered::new(self, limit), + ) + } + /// Flattens a stream of streams into just one continuous stream. /// /// If this stream's elements are themselves streams then this combinator diff --git a/futures-util/src/stream/try_stream/try_chunks.rs b/futures-util/src/stream/try_stream/try_chunks.rs index 3bb253a714..ec53f4bd11 100644 --- a/futures-util/src/stream/try_stream/try_chunks.rs +++ b/futures-util/src/stream/try_stream/try_chunks.rs @@ -41,9 +41,10 @@ impl TryChunks { delegate_access_inner!(stream, St, (. .)); } +type TryChunksStreamError = TryChunksError<::Ok, ::Error>; + impl Stream for TryChunks { - #[allow(clippy::type_complexity)] - type Item = Result, TryChunksError>; + type Item = Result, TryChunksStreamError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.as_mut().project(); diff --git a/futures-util/src/stream/try_stream/try_flatten_unordered.rs b/futures-util/src/stream/try_stream/try_flatten_unordered.rs new file mode 100644 index 0000000000..aaad910bf0 --- /dev/null +++ b/futures-util/src/stream/try_stream/try_flatten_unordered.rs @@ -0,0 +1,133 @@ +use core::pin::Pin; + +use futures_core::ready; +use futures_core::stream::{FusedStream, Stream, TryStream}; +use futures_core::task::{Context, Poll}; +#[cfg(feature = "sink")] +use futures_sink::Sink; + +use pin_project_lite::pin_project; + +use crate::future::Either; +use crate::stream::stream::FlattenUnordered; +use crate::StreamExt; + +use super::IntoStream; + +delegate_all!( + /// Stream for the [`try_flatten_unordered`](super::TryStreamExt::try_flatten_unordered) method. + TryFlattenUnordered( + FlattenUnordered> + ): Debug + Sink + Stream + FusedStream + AccessInner[St, (. .)] + + New[ + |stream: St, limit: impl Into>| + TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams::new(stream).flatten_unordered(limit) + ] + where + St: TryStream, + St::Ok: TryStream, + St::Ok: Unpin, + ::Error: From +); + +pin_project! { + /// Emits either successful streams or single-item streams containing the underlying errors. + /// This's a wrapper for `FlattenUnordered` to reuse its logic over `TryStream`. + #[derive(Debug)] + #[must_use = "streams do nothing unless polled"] + pub struct TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams + where + St: TryStream, + St::Ok: TryStream, + St::Ok: Unpin, + ::Error: From + { + #[pin] + stream: St, + } +} + +impl TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams +where + St: TryStream, + St::Ok: TryStream + Unpin, + ::Error: From, +{ + fn new(stream: St) -> Self { + Self { stream } + } + + delegate_access_inner!(stream, St, ()); +} + +impl FusedStream for TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams +where + St: TryStream + FusedStream, + St::Ok: TryStream + Unpin, + ::Error: From, +{ + fn is_terminated(&self) -> bool { + self.stream.is_terminated() + } +} + +/// Emits single item immediately, then stream will be terminated. +#[derive(Debug, Clone)] +pub struct Single(Option); + +impl Unpin for Single {} + +impl Stream for Single { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(self.0.take()) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.as_ref().map_or((0, Some(0)), |_| (1, Some(1))) + } +} + +type SingleStreamResult = Single::Ok, ::Error>>; + +impl Stream for TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams +where + St: TryStream, + St::Ok: TryStream + Unpin, + ::Error: From, +{ + // Item is either an inner stream or a stream containing a single error. + // This will allow using `Either`'s `Stream` implementation as both branches are actually streams of `Result`'s. + type Item = Either, SingleStreamResult>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let item = ready!(self.project().stream.try_poll_next(cx)); + + let out = item.map(|res| match res { + // Emit successful inner stream as is + Ok(stream) => Either::Left(IntoStream::new(stream)), + // Wrap an error into a stream containing a single item + err @ Err(_) => { + let res = err.map(|_: St::Ok| unreachable!()).map_err(Into::into); + + Either::Right(Single(Some(res))) + } + }); + + Poll::Ready(out) + } +} + +// Forwarding impl of Sink from the underlying stream +#[cfg(feature = "sink")] +impl Sink for TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams +where + St: TryStream + Sink, + St::Ok: Stream> + Unpin, + ::Error: From<::Error>, +{ + type Error = >::Error; + + delegate_sink!(stream, Item); +} diff --git a/futures/tests/stream_try_stream.rs b/futures/tests/stream_try_stream.rs index 194e74db74..6d00097970 100644 --- a/futures/tests/stream_try_stream.rs +++ b/futures/tests/stream_try_stream.rs @@ -2,6 +2,7 @@ use futures::{ stream::{self, StreamExt, TryStreamExt}, task::Poll, }; +use futures_executor::block_on; use futures_test::task::noop_context; #[test] @@ -36,3 +37,43 @@ fn try_take_while_after_err() { .boxed(); assert_eq!(Poll::Ready(None), s.poll_next_unpin(cx)); } + +#[test] +fn try_flatten_unordered() { + let s = stream::iter(1..7) + .map(|val: u32| { + if val % 2 == 0 { + Ok(stream::unfold((val, 1), |(val, pow)| async move { + Some((val.pow(pow), (val, pow + 1))) + }) + .take(3) + .map(move |val| if val % 16 != 0 { Ok(val) } else { Err(val) })) + } else { + Err(val) + } + }) + .map_ok(Box::pin) + .try_flatten_unordered(None); + + block_on(async move { + assert_eq!( + // All numbers can be divided by 16 and odds must be `Err` + // For all basic evens we must have powers from 1 to 3 + vec![ + Err(1), + Ok(2), + Err(3), + Ok(4), + Err(5), + Ok(6), + Ok(4), + Err(16), + Ok(36), + Ok(8), + Err(64), + Ok(216) + ], + s.collect::>().await + ) + }) +}