Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JoinSet #4335

Merged
merged 21 commits into from Feb 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion tokio/Cargo.toml
Expand Up @@ -138,7 +138,7 @@ socket2 = "0.4"
mio-aio = { version = "0.6.0", features = ["tokio"] }

[target.'cfg(loom)'.dev-dependencies]
loom = { version = "0.5", features = ["futures", "checkpoint"] }
loom = { version = "0.5.2", features = ["futures", "checkpoint"] }

[package.metadata.docs.rs]
all-features = true
Expand Down
7 changes: 3 additions & 4 deletions tokio/src/runtime/basic_scheduler.rs
@@ -1,6 +1,6 @@
use crate::future::poll_fn;
use crate::loom::sync::atomic::AtomicBool;
use crate::loom::sync::Mutex;
use crate::loom::sync::{Arc, Mutex};
use crate::park::{Park, Unpark};
use crate::runtime::context::EnterGuard;
use crate::runtime::driver::Driver;
Expand All @@ -16,7 +16,6 @@ use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::sync::atomic::Ordering::{AcqRel, Release};
use std::sync::Arc;
use std::task::Poll::{Pending, Ready};
use std::time::Duration;

Expand Down Expand Up @@ -481,8 +480,8 @@ impl Schedule for Arc<Shared> {
}

impl Wake for Shared {
fn wake(self: Arc<Self>) {
Wake::wake_by_ref(&self)
fn wake(arc_self: Arc<Self>) {
Wake::wake_by_ref(&arc_self)
}

/// Wake by reference
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/runtime/task/harness.rs
Expand Up @@ -164,6 +164,13 @@ where
}
}

/// Try to set the waker notified when the task is complete. Returns true if
/// the task has already completed. If this call returns false, then the
/// waker will not be notified.
pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool {
can_read_output(self.header(), self.trailer(), waker)
}

pub(super) fn drop_join_handle_slow(self) {
let mut maybe_panic = None;

Expand Down
12 changes: 11 additions & 1 deletion tokio/src/runtime/task/join.rs
Expand Up @@ -5,7 +5,7 @@ use std::future::Future;
use std::marker::PhantomData;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{Context, Poll, Waker};

cfg_rt! {
/// An owned permission to join on a task (await its termination).
Expand Down Expand Up @@ -200,6 +200,16 @@ impl<T> JoinHandle<T> {
raw.remote_abort();
}
}

/// Set the waker that is notified when the task completes.
pub(crate) fn set_join_waker(&mut self, waker: &Waker) {
if let Some(raw) = self.raw {
if raw.try_set_join_waker(waker) {
// In this case the task has already completed. We wake the waker immediately.
waker.wake_by_ref();
}
}
}
}

impl<T> Unpin for JoinHandle<T> {}
Expand Down
16 changes: 16 additions & 0 deletions tokio/src/runtime/task/raw.rs
Expand Up @@ -19,6 +19,11 @@ pub(super) struct Vtable {
/// Reads the task output, if complete.
pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker),

/// Try to set the waker notified when the task is complete. Returns true if
/// the task has already completed. If this call returns false, then the
/// waker will not be notified.
pub(super) try_set_join_waker: unsafe fn(NonNull<Header>, &Waker) -> bool,

/// The join handle has been dropped.
pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),

Expand All @@ -35,6 +40,7 @@ pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
poll: poll::<T, S>,
dealloc: dealloc::<T, S>,
try_read_output: try_read_output::<T, S>,
try_set_join_waker: try_set_join_waker::<T, S>,
drop_join_handle_slow: drop_join_handle_slow::<T, S>,
remote_abort: remote_abort::<T, S>,
shutdown: shutdown::<T, S>,
Expand Down Expand Up @@ -84,6 +90,11 @@ impl RawTask {
(vtable.try_read_output)(self.ptr, dst, waker);
}

pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool {
let vtable = self.header().vtable;
unsafe { (vtable.try_set_join_waker)(self.ptr, waker) }
}

pub(super) fn drop_join_handle_slow(self) {
let vtable = self.header().vtable;
unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
Expand Down Expand Up @@ -129,6 +140,11 @@ unsafe fn try_read_output<T: Future, S: Schedule>(
harness.try_read_output(out, waker);
}

unsafe fn try_set_join_waker<T: Future, S: Schedule>(ptr: NonNull<Header>, waker: &Waker) -> bool {
let harness = Harness::<T, S>::from_raw(ptr);
harness.try_set_join_waker(waker)
}

unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.drop_join_handle_slow()
Expand Down
82 changes: 82 additions & 0 deletions tokio/src/runtime/tests/loom_join_set.rs
@@ -0,0 +1,82 @@
use crate::runtime::Builder;
use crate::task::JoinSet;

#[test]
fn test_join_set() {
loom::model(|| {
let rt = Builder::new_multi_thread()
.worker_threads(1)
.build()
.unwrap();
let mut set = JoinSet::new();

rt.block_on(async {
assert_eq!(set.len(), 0);
set.spawn(async { () });
assert_eq!(set.len(), 1);
set.spawn(async { () });
assert_eq!(set.len(), 2);
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 1);
set.spawn(async { () });
assert_eq!(set.len(), 2);
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 1);
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 0);
set.spawn(async { () });
assert_eq!(set.len(), 1);
});

drop(set);
drop(rt);
});
}

#[test]
fn abort_all_during_completion() {
use std::sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
};

// These booleans assert that at least one execution had the task complete first, and that at
// least one execution had the task be cancelled before it completed.
let complete_happened = Arc::new(AtomicBool::new(false));
let cancel_happened = Arc::new(AtomicBool::new(false));

{
let complete_happened = complete_happened.clone();
let cancel_happened = cancel_happened.clone();
loom::model(move || {
let rt = Builder::new_multi_thread()
.worker_threads(1)
.build()
.unwrap();

let mut set = JoinSet::new();

rt.block_on(async {
set.spawn(async { () });
set.abort_all();

match set.join_one().await {
Ok(Some(())) => complete_happened.store(true, SeqCst),
Err(err) if err.is_cancelled() => cancel_happened.store(true, SeqCst),
Err(err) => panic!("fail: {}", err),
Ok(None) => {
unreachable!("Aborting the task does not remove it from the JoinSet.")
}
}

assert!(matches!(set.join_one().await, Ok(None)));
});

drop(set);
drop(rt);
});
}

assert!(complete_happened.load(SeqCst));
assert!(cancel_happened.load(SeqCst));
}
3 changes: 2 additions & 1 deletion tokio/src/runtime/tests/mod.rs
Expand Up @@ -30,12 +30,13 @@ mod unowned_wrapper {

cfg_loom! {
mod loom_basic_scheduler;
mod loom_local;
mod loom_blocking;
mod loom_local;
mod loom_oneshot;
mod loom_pool;
mod loom_queue;
mod loom_shutdown_join;
mod loom_join_set;
}

cfg_not_loom! {
Expand Down