diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 513671d097f..952ae93ea68 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -1,5 +1,6 @@ //! Runs `!Send` futures on the current thread. use crate::loom::sync::{Arc, Mutex}; +use crate::loom::thread::{self, ThreadId}; use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; use crate::sync::AtomicWaker; use crate::util::{RcCell, VecDequeCell}; @@ -228,9 +229,6 @@ struct Context { /// Collection of all active tasks spawned onto this executor. owned: LocalOwnedTasks>, - /// Local run queue sender and receiver. - queue: VecDequeCell>>, - /// State shared between threads. shared: Arc, @@ -241,6 +239,19 @@ struct Context { /// LocalSet state shared between threads. struct Shared { + /// Local run queue sender and receiver. + /// + /// # Safety + /// + /// This field must *only* be accessed from the thread that owns the + /// `LocalSet` (i.e., `Thread::current().id() == owner`). + local_queue: VecDequeCell>>, + + /// The `ThreadId` of the thread that owns the `LocalSet`. + /// + /// Since `LocalSet` is `!Send`, this will never change. + owner: ThreadId, + /// Remote run queue sender. queue: Mutex>>>>, @@ -262,10 +273,21 @@ pin_project! { } #[cfg(any(loom, tokio_no_const_thread_local))] -thread_local!(static CURRENT: RcCell = RcCell::new()); +thread_local!(static CURRENT: LocalData = LocalData { + thread_id: Cell::new(None), + ctx: RcCell::new(), +}); #[cfg(not(any(loom, tokio_no_const_thread_local)))] -thread_local!(static CURRENT: RcCell = const { RcCell::new() }); +thread_local!(static CURRENT: LocalData = const { LocalData { + thread_id: Cell::new(None), + ctx: RcCell::new(), +} }); + +struct LocalData { + thread_id: Cell>, + ctx: RcCell, +} cfg_rt! { /// Spawns a `!Send` future on the local task set. @@ -314,7 +336,7 @@ cfg_rt! { where F: Future + 'static, F::Output: 'static { - match CURRENT.with(|maybe_cx| maybe_cx.get()) { + match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { None => panic!("`spawn_local` called from outside of a `task::LocalSet`"), Some(cx) => cx.spawn(future, name) } @@ -335,7 +357,7 @@ pub struct LocalEnterGuard(Option>); impl Drop for LocalEnterGuard { fn drop(&mut self) { - CURRENT.with(|ctx| { + CURRENT.with(|LocalData { ctx, .. }| { ctx.set(self.0.take()); }) } @@ -354,8 +376,9 @@ impl LocalSet { tick: Cell::new(0), context: Rc::new(Context { owned: LocalOwnedTasks::new(), - queue: VecDequeCell::with_capacity(INITIAL_CAPACITY), shared: Arc::new(Shared { + local_queue: VecDequeCell::with_capacity(INITIAL_CAPACITY), + owner: thread::current().id(), queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), waker: AtomicWaker::new(), #[cfg(tokio_unstable)] @@ -374,7 +397,7 @@ impl LocalSet { /// /// [`spawn_local`]: fn@crate::task::spawn_local pub fn enter(&self) -> LocalEnterGuard { - CURRENT.with(|ctx| { + CURRENT.with(|LocalData { ctx, .. }| { let old = ctx.replace(Some(self.context.clone())); LocalEnterGuard(old) }) @@ -597,9 +620,9 @@ impl LocalSet { .lock() .as_mut() .and_then(|queue| queue.pop_front()) - .or_else(|| self.context.queue.pop_front()) + .or_else(|| self.pop_local()) } else { - self.context.queue.pop_front().or_else(|| { + self.pop_local().or_else(|| { self.context .shared .queue @@ -612,8 +635,17 @@ impl LocalSet { task.map(|task| self.context.owned.assert_owner(task)) } + fn pop_local(&self) -> Option>> { + unsafe { + // Safety: because the `LocalSet` itself is `!Send`, we know we are + // on the same thread if we have access to the `LocalSet`, and can + // therefore access the local run queue. + self.context.shared.local_queue().pop_front() + } + } + fn with(&self, f: impl FnOnce() -> T) -> T { - CURRENT.with(|ctx| { + CURRENT.with(|LocalData { ctx, .. }| { struct Reset<'a> { ctx_ref: &'a RcCell, val: Option>, @@ -639,7 +671,7 @@ impl LocalSet { fn with_if_possible(&self, f: impl FnOnce() -> T) -> T { let mut f = Some(f); - let res = CURRENT.try_with(|ctx| { + let res = CURRENT.try_with(|LocalData { ctx, .. }| { struct Reset<'a> { ctx_ref: &'a RcCell, val: Option>, @@ -782,7 +814,21 @@ impl Drop for LocalSet { // We already called shutdown on all tasks above, so there is no // need to call shutdown. - for task in self.context.queue.take() { + + // Safety: note that this *intentionally* bypasses the unsafe + // `Shared::local_queue()` method. This is in order to avoid the + // debug assertion that we are on the thread that owns the + // `LocalSet`, because on some systems (e.g. at least some macOS + // versions), attempting to get the current thread ID can panic due + // to the thread's local data that stores the thread ID being + // dropped *before* the `LocalSet`. + // + // Despite avoiding the assertion here, it is safe for us to access + // the local queue in `Drop`, because the `LocalSet` itself is + // `!Send`, so we can reasonably guarantee that it will not be + // `Drop`ped from another thread. + let local_queue = self.context.shared.local_queue.take(); + for task in local_queue { drop(task); } @@ -854,15 +900,48 @@ impl Future for RunUntil<'_, T> { } impl Shared { + /// # Safety + /// + /// This is safe to call if and ONLY if we are on the thread that owns this + /// `LocalSet`. + unsafe fn local_queue(&self) -> &VecDequeCell>> { + debug_assert!( + // if we couldn't get the thread ID because we're dropping the local + // data, skip the assertion --- the `Drop` impl is not going to be + // called from another thread, because `LocalSet` is `!Send` + thread_id().map(|id| id == self.owner).unwrap_or(true), + "`LocalSet`'s local run queue must not be accessed by another thread!" + ); + &self.local_queue + } + /// Schedule the provided task on the scheduler. fn schedule(&self, task: task::Notified>) { - CURRENT.with(|maybe_cx| { - match maybe_cx.get() { - Some(cx) if cx.shared.ptr_eq(self) => { - cx.queue.push_back(task); + CURRENT.with(|localdata| { + match localdata.ctx.get() { + Some(cx) if cx.shared.ptr_eq(self) => unsafe { + // Safety: if the current `LocalSet` context points to this + // `LocalSet`, then we are on the thread that owns it. + cx.shared.local_queue().push_back(task); + }, + + // We are on the thread that owns the `LocalSet`, so we can + // wake to the local queue. + _ if localdata.get_or_insert_id() == self.owner => { + unsafe { + // Safety: we just checked that the thread ID matches + // the localset's owner, so this is safe. + self.local_queue().push_back(task); + } + // We still have to wake the `LocalSet`, because it isn't + // currently being polled. + self.waker.wake(); } + + // We are *not* on the thread that owns the `LocalSet`, so we + // have to wake to the remote queue. _ => { - // First check whether the queue is still there (if not, the + // First, check whether the queue is still there (if not, the // LocalSet is dropped). Then push to it if so, and if not, // do nothing. let mut lock = self.queue.lock(); @@ -882,9 +961,13 @@ impl Shared { } } +// This is safe because (and only because) we *pinky pwomise* to never touch the +// local run queue except from the thread that owns the `LocalSet`. +unsafe impl Sync for Shared {} + impl task::Schedule for Arc { fn release(&self, task: &Task) -> Option> { - CURRENT.with(|maybe_cx| match maybe_cx.get() { + CURRENT.with(|LocalData { ctx, .. }| match ctx.get() { None => panic!("scheduler context missing"), Some(cx) => { assert!(cx.shared.ptr_eq(self)); @@ -909,7 +992,7 @@ impl task::Schedule for Arc { // This hook is only called from within the runtime, so // `CURRENT` should match with `&self`, i.e. there is no // opportunity for a nested scheduler to be called. - CURRENT.with(|maybe_cx| match maybe_cx.get() { + CURRENT.with(|LocalData { ctx, .. }| match ctx.get() { Some(cx) if Arc::ptr_eq(self, &cx.shared) => { cx.unhandled_panic.set(true); cx.owned.close_and_shutdown_all(); @@ -922,9 +1005,31 @@ impl task::Schedule for Arc { } } -#[cfg(test)] +impl LocalData { + fn get_or_insert_id(&self) -> ThreadId { + self.thread_id.get().unwrap_or_else(|| { + let id = thread::current().id(); + self.thread_id.set(Some(id)); + id + }) + } +} + +fn thread_id() -> Option { + CURRENT + .try_with(|localdata| localdata.get_or_insert_id()) + .ok() +} + +#[cfg(all(test, not(loom)))] mod tests { use super::*; + + // Does a `LocalSet` running on a current-thread runtime...basically work? + // + // This duplicates a test in `tests/task_local_set.rs`, but because this is + // a lib test, it wil run under Miri, so this is necessary to catch stacked + // borrows violations in the `LocalSet` implementation. #[test] fn local_current_thread_scheduler() { let f = async { @@ -939,4 +1044,49 @@ mod tests { .expect("rt") .block_on(f) } + + // Tests that when a task on a `LocalSet` is woken by an io driver on the + // same thread, the task is woken to the localset's local queue rather than + // its remote queue. + // + // This test has to be defined in the `local.rs` file as a lib test, rather + // than in `tests/`, because it makes assertions about the local set's + // internal state. + #[test] + fn wakes_to_local_queue() { + use super::*; + use crate::sync::Notify; + let rt = crate::runtime::Builder::new_current_thread() + .build() + .expect("rt"); + rt.block_on(async { + let local = LocalSet::new(); + let notify = Arc::new(Notify::new()); + let task = local.spawn_local({ + let notify = notify.clone(); + async move { + notify.notified().await; + } + }); + let mut run_until = Box::pin(local.run_until(async move { + task.await.unwrap(); + })); + + // poll the run until future once + crate::future::poll_fn(|cx| { + let _ = run_until.as_mut().poll(cx); + Poll::Ready(()) + }) + .await; + + notify.notify_one(); + let task = unsafe { local.context.shared.local_queue().pop_front() }; + // TODO(eliza): it would be nice to be able to assert that this is + // the local task. + assert!( + task.is_some(), + "task should have been notified to the LocalSet's local queue" + ); + }) + } }