diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index e3ba50a79ff..d37607a4b5d 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -10,6 +10,7 @@ use std::fmt; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; +use std::rc::Rc; use std::task::Poll; use pin_project_lite::pin_project; @@ -215,7 +216,7 @@ cfg_rt! { tick: Cell, /// State available from thread-local. - context: Context, + context: Rc, /// This type should not be Send. _not_send: PhantomData<*const ()>, @@ -260,7 +261,7 @@ pin_project! { } } -scoped_thread_local!(static CURRENT: Context); +thread_local!(static CURRENT: Cell>> = Cell::new(None)); cfg_rt! { /// Spawns a `!Send` future on the local task set. @@ -310,10 +311,12 @@ cfg_rt! { F::Output: 'static { CURRENT.with(|maybe_cx| { - let cx = maybe_cx - .expect("`spawn_local` called from outside of a `task::LocalSet`"); + let ctx = clone_rc(maybe_cx); + match ctx { + None => panic!("`spawn_local` called from outside of a `task::LocalSet`"), + Some(cx) => cx.spawn(future, name) + } - cx.spawn(future, name) }) } } @@ -327,12 +330,29 @@ const MAX_TASKS_PER_TICK: usize = 61; /// How often it check the remote queue first. const REMOTE_FIRST_INTERVAL: u8 = 31; +/// Context guard for LocalSet +pub struct LocalEnterGuard(Option>); + +impl Drop for LocalEnterGuard { + fn drop(&mut self) { + CURRENT.with(|ctx| { + ctx.replace(self.0.take()); + }) + } +} + +impl fmt::Debug for LocalEnterGuard { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LocalEnterGuard").finish() + } +} + impl LocalSet { /// Returns a new local task set. pub fn new() -> LocalSet { LocalSet { tick: Cell::new(0), - context: Context { + context: Rc::new(Context { owned: LocalOwnedTasks::new(), queue: VecDequeCell::with_capacity(INITIAL_CAPACITY), shared: Arc::new(Shared { @@ -342,11 +362,24 @@ impl LocalSet { unhandled_panic: crate::runtime::UnhandledPanic::Ignore, }), unhandled_panic: Cell::new(false), - }, + }), _not_send: PhantomData, } } + /// Enters the context of this `LocalSet`. + /// + /// The [`spawn_local`] method will spawn tasks on the `LocalSet` whose + /// context you are inside. + /// + /// [`spawn_local`]: fn@crate::task::spawn_local + pub fn enter(&self) -> LocalEnterGuard { + CURRENT.with(|ctx| { + let old = ctx.replace(Some(self.context.clone())); + LocalEnterGuard(old) + }) + } + /// Spawns a `!Send` task onto the local task set. /// /// This task is guaranteed to be run on the current thread. @@ -579,7 +612,25 @@ impl LocalSet { } fn with(&self, f: impl FnOnce() -> T) -> T { - CURRENT.set(&self.context, f) + CURRENT.with(|ctx| { + struct Reset<'a> { + ctx_ref: &'a Cell>>, + val: Option>, + } + impl<'a> Drop for Reset<'a> { + fn drop(&mut self) { + self.ctx_ref.replace(self.val.take()); + } + } + let old = ctx.replace(Some(self.context.clone())); + + let _reset = Reset { + ctx_ref: ctx, + val: old, + }; + + f() + }) } } @@ -645,8 +696,9 @@ cfg_unstable! { /// [`JoinHandle`]: struct@crate::task::JoinHandle pub fn unhandled_panic(&mut self, behavior: crate::runtime::UnhandledPanic) -> &mut Self { // TODO: This should be set as a builder - Arc::get_mut(&mut self.context.shared) - .expect("TODO: we shouldn't panic") + Rc::get_mut(&mut self.context) + .and_then(|ctx| Arc::get_mut(&mut ctx.shared)) + .expect("Unhandled Panic behavior modified after starting LocalSet") .unhandled_panic = behavior; self } @@ -769,23 +821,33 @@ impl Future for RunUntil<'_, T> { } } +fn clone_rc(rc: &Cell>>) -> Option> { + let value = rc.take(); + let cloned = value.clone(); + rc.set(value); + cloned +} + impl Shared { /// Schedule the provided task on the scheduler. fn schedule(&self, task: task::Notified>) { - CURRENT.with(|maybe_cx| match maybe_cx { - Some(cx) if cx.shared.ptr_eq(self) => { - cx.queue.push_back(task); - } - _ => { - // First check whether the queue is still there (if not, the - // LocalSet is dropped). Then push to it if so, and if not, - // do nothing. - let mut lock = self.queue.lock(); - - if let Some(queue) = lock.as_mut() { - queue.push_back(task); - drop(lock); - self.waker.wake(); + CURRENT.with(|maybe_cx| { + let ctx = clone_rc(maybe_cx); + match ctx { + Some(cx) if cx.shared.ptr_eq(self) => { + cx.queue.push_back(task); + } + _ => { + // First check whether the queue is still there (if not, the + // LocalSet is dropped). Then push to it if so, and if not, + // do nothing. + let mut lock = self.queue.lock(); + + if let Some(queue) = lock.as_mut() { + queue.push_back(task); + drop(lock); + self.waker.wake(); + } } } }); @@ -799,9 +861,14 @@ impl Shared { impl task::Schedule for Arc { fn release(&self, task: &Task) -> Option> { CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - assert!(cx.shared.ptr_eq(self)); - cx.owned.remove(task) + let ctx = clone_rc(maybe_cx); + match ctx { + None => panic!("scheduler context missing"), + Some(cx) => { + assert!(cx.shared.ptr_eq(self)); + cx.owned.remove(task) + } + } }) } @@ -821,13 +888,15 @@ impl task::Schedule for Arc { // This hook is only called from within the runtime, so // `CURRENT` should match with `&self`, i.e. there is no // opportunity for a nested scheduler to be called. - CURRENT.with(|maybe_cx| match maybe_cx { + CURRENT.with(|maybe_cx| { + let ctx = clone_rc(maybe_cx); + match ctx { Some(cx) if Arc::ptr_eq(self, &cx.shared) => { cx.unhandled_panic.set(true); cx.owned.close_and_shutdown_all(); } _ => unreachable!("runtime core not set in CURRENT thread-local"), - }) + }}) } } } diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index 8ea73fb48de..409cd1327b1 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -297,7 +297,7 @@ cfg_rt! { } mod local; - pub use local::{spawn_local, LocalSet}; + pub use local::{spawn_local, LocalSet, LocalEnterGuard}; mod task_local; pub use task_local::LocalKey; diff --git a/tokio/tests/task_local_set.rs b/tokio/tests/task_local_set.rs index 2bd5b17b84d..0cdded4a6d1 100644 --- a/tokio/tests/task_local_set.rs +++ b/tokio/tests/task_local_set.rs @@ -126,6 +126,21 @@ async fn local_threadpool_timer() { }) .await; } +#[test] +fn enter_guard_spawn() { + let local = LocalSet::new(); + let _guard = local.enter(); + // Run the local task set. + + let join = task::spawn_local(async { true }); + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + local.block_on(&rt, async move { + assert!(join.await.unwrap()); + }); +} #[test] // This will panic, since the thread that calls `block_on` cannot use