diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 4427c8a2efc..d9e8428a837 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -16,7 +16,15 @@ cfg_rt! { struct Context { /// Handle to the runtime scheduler running on the current thread. #[cfg(feature = "rt")] - scheduler: RefCell>, + handle: RefCell>, + + /// Tracks if the current thread is currently driving a runtime. + /// Note, that if this is set to "entered", the current scheduler + /// handle may not reference the runtime currently executing. This + /// is because other runtime handles may be set to current from + /// within a runtime. + #[cfg(feature = "rt")] + runtime: Cell, #[cfg(any(feature = "rt", feature = "macros"))] rng: FastRand, @@ -29,19 +37,27 @@ struct Context { tokio_thread_local! { static CONTEXT: Context = { Context { + /// Tracks the current runtime handle to use when spawning, + /// accessing drivers, etc... + #[cfg(feature = "rt")] + handle: RefCell::new(None), + + /// Tracks if the current thread is currently driving a runtime. + /// Note, that if this is set to "entered", the current scheduler + /// handle may not reference the runtime currently executing. This + /// is because other runtime handles may be set to current from + /// within a runtime. #[cfg(feature = "rt")] - scheduler: RefCell::new(None), + runtime: Cell::new(EnterRuntime::NotEntered), #[cfg(any(feature = "rt", feature = "macros"))] rng: FastRand::new(RngSeed::new()), + budget: Cell::new(coop::Budget::unconstrained()), } } } -#[cfg(feature = "rt")] -tokio_thread_local!(static ENTERED: Cell = const { Cell::new(EnterRuntime::NotEntered) }); - #[cfg(feature = "macros")] pub(crate) fn thread_rng_n(n: u32) -> u32 { CONTEXT.with(|ctx| ctx.rng.fastrand_n(n)) @@ -86,7 +102,7 @@ cfg_rt! { pub(crate) struct DisallowBlockInPlaceGuard(bool); pub(crate) fn try_current() -> Result { - match CONTEXT.try_with(|ctx| ctx.scheduler.borrow().clone()) { + match CONTEXT.try_with(|ctx| ctx.handle.borrow().clone()) { Ok(Some(handle)) => Ok(handle), Ok(None) => Err(TryCurrentError::new_no_context()), Err(_access_error) => Err(TryCurrentError::new_thread_local_destroyed()), @@ -97,17 +113,7 @@ cfg_rt! { /// /// [`Handle`]: crate::runtime::scheduler::Handle pub(crate) fn try_set_current(handle: &scheduler::Handle) -> Option { - let rng_seed = handle.seed_generator().next_seed(); - - CONTEXT.try_with(|ctx| { - let old_handle = ctx.scheduler.borrow_mut().replace(handle.clone()); - let old_seed = ctx.rng.replace_seed(rng_seed); - - SetCurrentGuard { - old_handle, - old_seed, - } - }).ok() + CONTEXT.try_with(|ctx| ctx.set_current(handle)).ok() } @@ -133,11 +139,11 @@ cfg_rt! { /// Tries to enter a runtime context, returns `None` if already in a runtime /// context. fn try_enter_runtime(allow_block_in_place: bool) -> Option { - ENTERED.with(|c| { - if c.get().is_entered() { + CONTEXT.with(|c| { + if c.runtime.get().is_entered() { None } else { - c.set(EnterRuntime::Entered { allow_block_in_place }); + c.runtime.set(EnterRuntime::Entered { allow_block_in_place }); Some(EnterRuntimeGuard { blocking: BlockingRegionGuard::new(), }) @@ -146,23 +152,27 @@ cfg_rt! { } pub(crate) fn try_enter_blocking_region() -> Option { - ENTERED.with(|c| { - if c.get().is_entered() { + CONTEXT.try_with(|c| { + if c.runtime.get().is_entered() { None } else { Some(BlockingRegionGuard::new()) } - }) + // If accessing the thread-local fails, the thread is terminating + // and thread-locals are being destroyed. Because we don't know if + // we are currently in a runtime or not, we default to being + // permissive. + }).unwrap_or_else(|_| Some(BlockingRegionGuard::new())) } /// Disallows blocking in the current runtime context until the guard is dropped. pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard { - let reset = ENTERED.with(|c| { + let reset = CONTEXT.with(|c| { if let EnterRuntime::Entered { allow_block_in_place: true, - } = c.get() + } = c.runtime.get() { - c.set(EnterRuntime::Entered { + c.runtime.set(EnterRuntime::Entered { allow_block_in_place: false, }); true @@ -170,13 +180,28 @@ cfg_rt! { false } }); + DisallowBlockInPlaceGuard(reset) } + impl Context { + fn set_current(&self, handle: &scheduler::Handle) -> SetCurrentGuard { + let rng_seed = handle.seed_generator().next_seed(); + + let old_handle = self.handle.borrow_mut().replace(handle.clone()); + let old_seed = self.rng.replace_seed(rng_seed); + + SetCurrentGuard { + old_handle, + old_seed, + } + } + } + impl Drop for SetCurrentGuard { fn drop(&mut self) { CONTEXT.with(|ctx| { - *ctx.scheduler.borrow_mut() = self.old_handle.take(); + *ctx.handle.borrow_mut() = self.old_handle.take(); ctx.rng.replace_seed(self.old_seed.clone()); }); } @@ -190,9 +215,9 @@ cfg_rt! { impl Drop for EnterRuntimeGuard { fn drop(&mut self) { - ENTERED.with(|c| { - assert!(c.get().is_entered()); - c.set(EnterRuntime::NotEntered); + CONTEXT.with(|c| { + assert!(c.runtime.get().is_entered()); + c.runtime.set(EnterRuntime::NotEntered); }); } } @@ -253,12 +278,12 @@ cfg_rt! { fn drop(&mut self) { if self.0 { // XXX: Do we want some kind of assertion here, or is "best effort" okay? - ENTERED.with(|c| { + CONTEXT.with(|c| { if let EnterRuntime::Entered { allow_block_in_place: false, - } = c.get() + } = c.runtime.get() { - c.set(EnterRuntime::Entered { + c.runtime.set(EnterRuntime::Entered { allow_block_in_place: true, }); } @@ -284,7 +309,7 @@ cfg_rt! { cfg_rt_multi_thread! { /// Returns true if in a runtime context. pub(crate) fn current_enter_context() -> EnterRuntime { - ENTERED.with(|c| c.get()) + CONTEXT.with(|c| c.runtime.get()) } pub(crate) fn exit_runtime R, R>(f: F) -> R { @@ -293,17 +318,17 @@ cfg_rt_multi_thread! { impl Drop for Reset { fn drop(&mut self) { - ENTERED.with(|c| { - assert!(!c.get().is_entered(), "closure claimed permanent executor"); - c.set(self.0); + CONTEXT.with(|c| { + assert!(!c.runtime.get().is_entered(), "closure claimed permanent executor"); + c.runtime.set(self.0); }); } } - let was = ENTERED.with(|c| { - let e = c.get(); + let was = CONTEXT.with(|c| { + let e = c.runtime.get(); assert!(e.is_entered(), "asked to exit when not entered"); - c.set(EnterRuntime::NotEntered); + c.runtime.set(EnterRuntime::NotEntered); e });