From 23a1ccf24fddb78abf5169053368c8b7fd6733dc Mon Sep 17 00:00:00 2001
From: Eliza Weisman
Date: Thu, 13 Oct 2022 11:27:19 -0700
Subject: [PATCH] task: wake local tasks to the local queue when woken by the same thread (#5095)

Motivation

Currently, when a task spawned on a `LocalSet` is woken by an I/O driver
or time driver running on the same thread as the `LocalSet`, the task is
pushed to the `LocalSet`'s locked remote run queue rather than to its
unsynchronized local run queue. This is unfortunate, as it negates some
of the performance benefits of having an unsynchronized local run queue.
As it stands, a task is only woken to the local queue when it is woken
by another task running on the same `LocalSet`. This occurs because the
local queue is only used when the `CURRENT` thread-local contains the
`Context` associated with the task's `Schedule` instance (an
`Arc<Shared>`). When the `LocalSet` is not being polled, the
thread-local is unset, and the local run queue cannot be accessed by the
`Schedule` implementation for `Arc<Shared>`.

Solution

This branch fixes this by moving the local run queue into `Shared`,
alongside the remote run queue. When the `Schedule` impl for
`Arc<Shared>` wakes a task and the `CURRENT` thread-local is `None`
(indicating the `LocalSet` is not currently being polled on this
thread), we now check whether the current thread's `ThreadId` matches
that of the thread the `LocalSet` was created on, and push the woken
task to the local queue if it does.

Moving the local run queue into `Shared` is somewhat unfortunate, as it
means `Shared` now has a field that must never be accessed from other
threads, which in turn requires an `unsafe impl Sync for Shared`.
However, it's the only viable way to wake to the local queue from the
`Schedule` impl for `Arc<Shared>`, so I figured it was worth the
additional unsafe code. I added a debug assertion that checks that the
local queue is only accessed from the thread that owns the `LocalSet`.
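As an illustration of the affected pattern, here is a minimal sketch
(editorial, not part of this patch; it assumes tokio's `rt` and `time`
features are enabled): a `!Send` task on a current-thread runtime whose
wakeups all come from the time driver running on the `LocalSet`'s own
thread.

    use std::rc::Rc;
    use tokio::task::LocalSet;
    use tokio::time::{sleep, Duration};

    fn main() {
        // A current-thread runtime: the time driver runs on the same
        // thread that polls the `LocalSet`.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .unwrap();

        let local = LocalSet::new();
        local.block_on(&rt, async {
            // `Rc` makes the future `!Send`, so it must stay on the
            // `LocalSet`.
            let data = Rc::new(42);
            tokio::task::spawn_local(async move {
                // Every timer wakeup here comes from the driver on this
                // same thread. Previously each such wake took the remote
                // queue's lock; with this change, it is pushed to the
                // unsynchronized local queue instead.
                sleep(Duration::from_millis(10)).await;
                println!("woken on the owning thread: {}", data);
            })
            .await
            .unwrap();
        });
    }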
---
 tokio/src/task/local.rs | 194 +++++++++++++++++++++++++++++++++++-----
 1 file changed, 172 insertions(+), 22 deletions(-)

diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs
index 513671d097f..952ae93ea68 100644
--- a/tokio/src/task/local.rs
+++ b/tokio/src/task/local.rs
@@ -1,5 +1,6 @@
 //! Runs `!Send` futures on the current thread.
 use crate::loom::sync::{Arc, Mutex};
+use crate::loom::thread::{self, ThreadId};
 use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
 use crate::sync::AtomicWaker;
 use crate::util::{RcCell, VecDequeCell};
@@ -228,9 +229,6 @@ struct Context {
     /// Collection of all active tasks spawned onto this executor.
     owned: LocalOwnedTasks<Arc<Shared>>,
 
-    /// Local run queue sender and receiver.
-    queue: VecDequeCell<task::Notified<Arc<Shared>>>,
-
     /// State shared between threads.
     shared: Arc<Shared>,
 
@@ -241,6 +239,19 @@
 /// LocalSet state shared between threads.
 struct Shared {
+    /// Local run queue sender and receiver.
+    ///
+    /// # Safety
+    ///
+    /// This field must *only* be accessed from the thread that owns the
+    /// `LocalSet` (i.e., `Thread::current().id() == owner`).
+    local_queue: VecDequeCell<task::Notified<Arc<Shared>>>,
+
+    /// The `ThreadId` of the thread that owns the `LocalSet`.
+    ///
+    /// Since `LocalSet` is `!Send`, this will never change.
+    owner: ThreadId,
+
     /// Remote run queue sender.
     queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>,
 
@@ -262,10 +273,21 @@ pin_project! {
 }
 
 #[cfg(any(loom, tokio_no_const_thread_local))]
-thread_local!(static CURRENT: RcCell<Context> = RcCell::new());
+thread_local!(static CURRENT: LocalData = LocalData {
+    thread_id: Cell::new(None),
+    ctx: RcCell::new(),
+});
 
 #[cfg(not(any(loom, tokio_no_const_thread_local)))]
-thread_local!(static CURRENT: RcCell<Context> = const { RcCell::new() });
+thread_local!(static CURRENT: LocalData = const { LocalData {
+    thread_id: Cell::new(None),
+    ctx: RcCell::new(),
+} });
+
+struct LocalData {
+    thread_id: Cell<Option<ThreadId>>,
+    ctx: RcCell<Context>,
+}
 
 cfg_rt! {
     /// Spawns a `!Send` future on the local task set.
@@ -314,7 +336,7 @@ cfg_rt! {
     where F: Future + 'static,
           F::Output: 'static
     {
-        match CURRENT.with(|maybe_cx| maybe_cx.get()) {
+        match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) {
             None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
             Some(cx) => cx.spawn(future, name)
         }
@@ -335,7 +357,7 @@ pub struct LocalEnterGuard(Option<Rc<Context>>);
 
 impl Drop for LocalEnterGuard {
     fn drop(&mut self) {
-        CURRENT.with(|ctx| {
+        CURRENT.with(|LocalData { ctx, .. }| {
             ctx.set(self.0.take());
         })
     }
@@ -354,8 +376,9 @@ impl LocalSet {
             tick: Cell::new(0),
             context: Rc::new(Context {
                 owned: LocalOwnedTasks::new(),
-                queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
                 shared: Arc::new(Shared {
+                    local_queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
+                    owner: thread::current().id(),
                     queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))),
                     waker: AtomicWaker::new(),
                     #[cfg(tokio_unstable)]
@@ -374,7 +397,7 @@ impl LocalSet {
     ///
     /// [`spawn_local`]: fn@crate::task::spawn_local
     pub fn enter(&self) -> LocalEnterGuard {
-        CURRENT.with(|ctx| {
+        CURRENT.with(|LocalData { ctx, .. }| {
             let old = ctx.replace(Some(self.context.clone()));
             LocalEnterGuard(old)
         })
@@ -597,9 +620,9 @@ impl LocalSet {
                 .lock()
                 .as_mut()
                 .and_then(|queue| queue.pop_front())
-                .or_else(|| self.context.queue.pop_front())
+                .or_else(|| self.pop_local())
         } else {
-            self.context.queue.pop_front().or_else(|| {
+            self.pop_local().or_else(|| {
                 self.context
                     .shared
                     .queue
@@ -612,8 +635,17 @@ impl LocalSet {
         task.map(|task| self.context.owned.assert_owner(task))
     }
 
+    fn pop_local(&self) -> Option<task::Notified<Arc<Shared>>> {
+        unsafe {
+            // Safety: because the `LocalSet` itself is `!Send`, we know we are
+            // on the same thread if we have access to the `LocalSet`, and can
+            // therefore access the local run queue.
+            self.context.shared.local_queue().pop_front()
+        }
+    }
+
     fn with<T>(&self, f: impl FnOnce() -> T) -> T {
-        CURRENT.with(|ctx| {
+        CURRENT.with(|LocalData { ctx, .. }| {
             struct Reset<'a> {
                 ctx_ref: &'a RcCell<Context>,
                 val: Option<Rc<Context>>,
             }
@@ -639,7 +671,7 @@ impl LocalSet {
     fn with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T {
         let mut f = Some(f);
 
-        let res = CURRENT.try_with(|ctx| {
+        let res = CURRENT.try_with(|LocalData { ctx, .. }| {
             struct Reset<'a> {
                 ctx_ref: &'a RcCell<Context>,
                 val: Option<Rc<Context>>,
             }
@@ -782,7 +814,21 @@ impl Drop for LocalSet {
 
             // We already called shutdown on all tasks above, so there is no
             // need to call shutdown.
-            for task in self.context.queue.take() {
+
+            // Safety: note that this *intentionally* bypasses the unsafe
+            // `Shared::local_queue()` method. This is in order to avoid the
+            // debug assertion that we are on the thread that owns the
+            // `LocalSet`, because on some systems (e.g. at least some macOS
+            // versions), attempting to get the current thread ID can panic due
+            // to the thread's local data that stores the thread ID being
+            // dropped *before* the `LocalSet`.
+            //
+            // Despite avoiding the assertion here, it is safe for us to access
+            // the local queue in `Drop`, because the `LocalSet` itself is
+            // `!Send`, so we can reasonably guarantee that it will not be
+            // `Drop`ped from another thread.
+            let local_queue = self.context.shared.local_queue.take();
+            for task in local_queue {
                 drop(task);
             }
 
@@ -854,15 +900,48 @@ impl<T: Future> Future for RunUntil<'_, T> {
 }
 
 impl Shared {
+    /// # Safety
+    ///
+    /// This is safe to call if and ONLY if we are on the thread that owns this
+    /// `LocalSet`.
+    unsafe fn local_queue(&self) -> &VecDequeCell<task::Notified<Arc<Shared>>> {
+        debug_assert!(
+            // if we couldn't get the thread ID because we're dropping the local
+            // data, skip the assertion --- the `Drop` impl is not going to be
+            // called from another thread, because `LocalSet` is `!Send`
+            thread_id().map(|id| id == self.owner).unwrap_or(true),
+            "`LocalSet`'s local run queue must not be accessed by another thread!"
+        );
+        &self.local_queue
+    }
+
     /// Schedule the provided task on the scheduler.
     fn schedule(&self, task: task::Notified<Arc<Shared>>) {
-        CURRENT.with(|maybe_cx| {
-            match maybe_cx.get() {
-                Some(cx) if cx.shared.ptr_eq(self) => {
-                    cx.queue.push_back(task);
+        CURRENT.with(|localdata| {
+            match localdata.ctx.get() {
+                Some(cx) if cx.shared.ptr_eq(self) => unsafe {
+                    // Safety: if the current `LocalSet` context points to this
+                    // `LocalSet`, then we are on the thread that owns it.
+                    cx.shared.local_queue().push_back(task);
+                },
+
+                // We are on the thread that owns the `LocalSet`, so we can
+                // wake to the local queue.
+                _ if localdata.get_or_insert_id() == self.owner => {
+                    unsafe {
+                        // Safety: we just checked that the thread ID matches
+                        // the `LocalSet`'s owner, so this is safe.
+                        self.local_queue().push_back(task);
+                    }
+                    // We still have to wake the `LocalSet`, because it isn't
+                    // currently being polled.
+                    self.waker.wake();
                 }
+
+                // We are *not* on the thread that owns the `LocalSet`, so we
+                // have to wake to the remote queue.
                 _ => {
-                    // First check whether the queue is still there (if not, the
+                    // First, check whether the queue is still there (if not, the
                     // LocalSet is dropped). Then push to it if so, and if not,
                     // do nothing.
                     let mut lock = self.queue.lock();
@@ -882,9 +961,13 @@ impl Shared {
     }
 }
 
+// This is safe because (and only because) we *pinky pwomise* to never touch the
+// local run queue except from the thread that owns the `LocalSet`.
+unsafe impl Sync for Shared {}
+
 impl task::Schedule for Arc<Shared> {
     fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
-        CURRENT.with(|maybe_cx| match maybe_cx.get() {
+        CURRENT.with(|LocalData { ctx, .. }| match ctx.get() {
             None => panic!("scheduler context missing"),
             Some(cx) => {
                 assert!(cx.shared.ptr_eq(self));
@@ -909,7 +992,7 @@ impl task::Schedule for Arc<Shared> {
             // This hook is only called from within the runtime, so
             // `CURRENT` should match with `&self`, i.e. there is no
             // opportunity for a nested scheduler to be called.
-            CURRENT.with(|maybe_cx| match maybe_cx.get() {
+            CURRENT.with(|LocalData { ctx, .. }| match ctx.get() {
                 Some(cx) if Arc::ptr_eq(self, &cx.shared) => {
                     cx.unhandled_panic.set(true);
                     cx.owned.close_and_shutdown_all();
                 }
@@ -922,9 +1005,31 @@ impl task::Schedule for Arc<Shared> {
     }
 }
 
-#[cfg(test)]
+impl LocalData {
+    fn get_or_insert_id(&self) -> ThreadId {
+        self.thread_id.get().unwrap_or_else(|| {
+            let id = thread::current().id();
+            self.thread_id.set(Some(id));
+            id
+        })
+    }
+}
+
+fn thread_id() -> Option<ThreadId> {
+    CURRENT
+        .try_with(|localdata| localdata.get_or_insert_id())
+        .ok()
+}
+
+#[cfg(all(test, not(loom)))]
 mod tests {
     use super::*;
+
+    // Does a `LocalSet` running on a current-thread runtime...basically work?
+    //
+    // This duplicates a test in `tests/task_local_set.rs`, but because this is
+    // a lib test, it will run under Miri, so this is necessary to catch stacked
+    // borrows violations in the `LocalSet` implementation.
     #[test]
     fn local_current_thread_scheduler() {
         let f = async {
@@ -939,4 +1044,49 @@ mod tests {
             .expect("rt")
             .block_on(f)
     }
+
+    // Tests that when a task on a `LocalSet` is woken by an I/O driver on the
+    // same thread, the task is woken to the `LocalSet`'s local queue rather
+    // than its remote queue.
+    //
+    // This test has to be defined in the `local.rs` file as a lib test, rather
+    // than in `tests/`, because it makes assertions about the `LocalSet`'s
+    // internal state.
+    #[test]
+    fn wakes_to_local_queue() {
+        use super::*;
+        use crate::sync::Notify;
+        let rt = crate::runtime::Builder::new_current_thread()
+            .build()
+            .expect("rt");
+        rt.block_on(async {
+            let local = LocalSet::new();
+            let notify = Arc::new(Notify::new());
+            let task = local.spawn_local({
+                let notify = notify.clone();
+                async move {
+                    notify.notified().await;
+                }
+            });
+            let mut run_until = Box::pin(local.run_until(async move {
+                task.await.unwrap();
+            }));
+
+            // poll the `run_until` future once
+            crate::future::poll_fn(|cx| {
+                let _ = run_until.as_mut().poll(cx);
+                Poll::Ready(())
+            })
+            .await;
+
+            notify.notify_one();
+            let task = unsafe { local.context.shared.local_queue().pop_front() };
+            // TODO(eliza): it would be nice to be able to assert that this is
+            // the local task.
+            assert!(
+                task.is_some(),
+                "task should have been notified to the LocalSet's local queue"
+            );
+        })
+    }
 }
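For readers skimming the diff, the core of the new wake path can be
summarized outside of patch context. The sketch below is a simplified,
self-contained model (names such as `DemoShared` are hypothetical and
only the standard library is assumed); it shows why the `unsafe impl
Sync` is sound only in combination with the owner `ThreadId` check:

    use std::cell::RefCell;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};
    use std::thread::{self, ThreadId};

    struct DemoShared {
        // The thread that owns the (stand-in for a) `LocalSet`.
        owner: ThreadId,
        // Unsynchronized queue; only the owner thread may touch it.
        local_queue: RefCell<VecDeque<u32>>,
        // Locked fallback queue for wakes from other threads.
        remote_queue: Mutex<VecDeque<u32>>,
    }

    // Mirrors the patch's `unsafe impl Sync for Shared`: on its own this
    // would be unsound, because `RefCell` is not thread-safe; it is only
    // justified by the `owner` check in `schedule` below.
    unsafe impl Sync for DemoShared {}

    impl DemoShared {
        fn schedule(&self, task: u32) {
            if thread::current().id() == self.owner {
                // Same thread as the owner: the cheap, unsynchronized path.
                self.local_queue.borrow_mut().push_back(task);
            } else {
                // Another thread: take the lock and use the remote queue.
                self.remote_queue.lock().unwrap().push_back(task);
            }
        }
    }

    fn main() {
        let shared = Arc::new(DemoShared {
            owner: thread::current().id(),
            local_queue: RefCell::new(VecDeque::new()),
            remote_queue: Mutex::new(VecDeque::new()),
        });

        // Owner thread: goes to the local queue.
        shared.schedule(1);

        // Another thread: goes to the remote queue.
        let s2 = Arc::clone(&shared);
        thread::spawn(move || s2.schedule(2)).join().unwrap();

        assert_eq!(shared.local_queue.borrow().len(), 1);
        assert_eq!(shared.remote_queue.lock().unwrap().len(), 1);
    }

The real `Shared::schedule` has a third, even cheaper case first: when
the `CURRENT` thread-local already holds this `LocalSet`'s context (the
set is being polled on this thread), the thread-ID comparison and the
extra `waker.wake()` call can both be skipped.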