diff --git a/tokio/src/future/block_on.rs b/tokio/src/future/block_on.rs index aff0075f98d..fedcdacd614 100644 --- a/tokio/src/future/block_on.rs +++ b/tokio/src/future/block_on.rs @@ -3,7 +3,7 @@ use std::future::Future; cfg_rt! { #[track_caller] pub(crate) fn block_on(f: F) -> F::Output { - let mut e = crate::runtime::enter::try_enter_blocking_region().expect( + let mut e = crate::runtime::context::try_enter_blocking_region().expect( "Cannot block the current thread from within a runtime. This \ happens because a functionattempted to block the current \ thread while the thread is being used to drive asynchronous \ diff --git a/tokio/src/runtime/blocking/shutdown.rs b/tokio/src/runtime/blocking/shutdown.rs index 59126cf83c3..fe5abae076d 100644 --- a/tokio/src/runtime/blocking/shutdown.rs +++ b/tokio/src/runtime/blocking/shutdown.rs @@ -35,7 +35,7 @@ impl Receiver { /// /// If the timeout has elapsed, it returns `false`, otherwise it returns `true`. pub(crate) fn wait(&mut self, timeout: Option) -> bool { - use crate::runtime::enter::try_enter_blocking_region; + use crate::runtime::context::try_enter_blocking_region; if timeout == Some(Duration::from_nanos(0)) { return false; diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 7c92019d75f..4427c8a2efc 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -7,7 +7,10 @@ use crate::util::rand::{FastRand, RngSeed}; cfg_rt! { use crate::runtime::scheduler; + use std::cell::RefCell; + use std::marker::PhantomData; + use std::time::Duration; } struct Context { @@ -17,6 +20,7 @@ struct Context { #[cfg(any(feature = "rt", feature = "macros"))] rng: FastRand, + /// Tracks the amount of "work" a task may still do before yielding back to /// the sheduler budget: Cell, @@ -35,6 +39,9 @@ tokio_thread_local! { } } +#[cfg(feature = "rt")] +tokio_thread_local!(static ENTERED: Cell = const { Cell::new(EnterRuntime::NotEntered) }); + #[cfg(feature = "macros")] pub(crate) fn thread_rng_n(n: u32) -> u32 { CONTEXT.with(|ctx| ctx.rng.fastrand_n(n)) @@ -45,14 +52,39 @@ pub(super) fn budget(f: impl FnOnce(&Cell) -> R) -> R { } cfg_rt! { + use crate::loom::thread::AccessError; use crate::runtime::TryCurrentError; + use std::fmt; + + #[derive(Debug, Clone, Copy)] + pub(crate) enum EnterRuntime { + /// Currently in a runtime context. + #[cfg_attr(not(feature = "rt"), allow(dead_code))] + Entered { allow_block_in_place: bool }, + + /// Not in a runtime context **or** a blocking region. + NotEntered, + } + #[derive(Debug)] pub(crate) struct SetCurrentGuard { old_handle: Option, old_seed: RngSeed, } + /// Guard tracking that a caller has entered a runtime context. + pub(crate) struct EnterRuntimeGuard { + pub(crate) blocking: BlockingRegionGuard, + } + + /// Guard tracking that a caller has entered a blocking region. + pub(crate) struct BlockingRegionGuard { + _p: PhantomData>, + } + + pub(crate) struct DisallowBlockInPlaceGuard(bool); + pub(crate) fn try_current() -> Result { match CONTEXT.try_with(|ctx| ctx.scheduler.borrow().clone()) { Ok(Some(handle)) => Ok(handle), @@ -78,6 +110,69 @@ cfg_rt! { }).ok() } + + /// Marks the current thread as being within the dynamic extent of an + /// executor. + #[track_caller] + pub(crate) fn enter_runtime(handle: &scheduler::Handle, allow_block_in_place: bool) -> EnterRuntimeGuard { + if let Some(enter) = try_enter_runtime(allow_block_in_place) { + // Set the current runtime handle. This should not fail. A later + // cleanup will remove the unwrap(). + try_set_current(handle).unwrap(); + return enter; + } + + panic!( + "Cannot start a runtime from within a runtime. This happens \ + because a function (like `block_on`) attempted to block the \ + current thread while the thread is being used to drive \ + asynchronous tasks." + ); + } + + /// Tries to enter a runtime context, returns `None` if already in a runtime + /// context. + fn try_enter_runtime(allow_block_in_place: bool) -> Option { + ENTERED.with(|c| { + if c.get().is_entered() { + None + } else { + c.set(EnterRuntime::Entered { allow_block_in_place }); + Some(EnterRuntimeGuard { + blocking: BlockingRegionGuard::new(), + }) + } + }) + } + + pub(crate) fn try_enter_blocking_region() -> Option { + ENTERED.with(|c| { + if c.get().is_entered() { + None + } else { + Some(BlockingRegionGuard::new()) + } + }) + } + + /// Disallows blocking in the current runtime context until the guard is dropped. + pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard { + let reset = ENTERED.with(|c| { + if let EnterRuntime::Entered { + allow_block_in_place: true, + } = c.get() + { + c.set(EnterRuntime::Entered { + allow_block_in_place: false, + }); + true + } else { + false + } + }); + DisallowBlockInPlaceGuard(reset) + } + impl Drop for SetCurrentGuard { fn drop(&mut self) { CONTEXT.with(|ctx| { @@ -86,4 +181,134 @@ cfg_rt! { }); } } + + impl fmt::Debug for EnterRuntimeGuard { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Enter").finish() + } + } + + impl Drop for EnterRuntimeGuard { + fn drop(&mut self) { + ENTERED.with(|c| { + assert!(c.get().is_entered()); + c.set(EnterRuntime::NotEntered); + }); + } + } + + impl BlockingRegionGuard { + fn new() -> BlockingRegionGuard { + BlockingRegionGuard { _p: PhantomData } + } + /// Blocks the thread on the specified future, returning the value with + /// which that future completes. + pub(crate) fn block_on(&mut self, f: F) -> Result + where + F: std::future::Future, + { + use crate::runtime::park::CachedParkThread; + + let mut park = CachedParkThread::new(); + park.block_on(f) + } + + /// Blocks the thread on the specified future for **at most** `timeout` + /// + /// If the future completes before `timeout`, the result is returned. If + /// `timeout` elapses, then `Err` is returned. + pub(crate) fn block_on_timeout(&mut self, f: F, timeout: Duration) -> Result + where + F: std::future::Future, + { + use crate::runtime::park::CachedParkThread; + use std::task::Context; + use std::task::Poll::Ready; + use std::time::Instant; + + let mut park = CachedParkThread::new(); + let waker = park.waker().map_err(|_| ())?; + let mut cx = Context::from_waker(&waker); + + pin!(f); + let when = Instant::now() + timeout; + + loop { + if let Ready(v) = crate::runtime::coop::budget(|| f.as_mut().poll(&mut cx)) { + return Ok(v); + } + + let now = Instant::now(); + + if now >= when { + return Err(()); + } + + park.park_timeout(when - now); + } + } + } + + impl Drop for DisallowBlockInPlaceGuard { + fn drop(&mut self) { + if self.0 { + // XXX: Do we want some kind of assertion here, or is "best effort" okay? + ENTERED.with(|c| { + if let EnterRuntime::Entered { + allow_block_in_place: false, + } = c.get() + { + c.set(EnterRuntime::Entered { + allow_block_in_place: true, + }); + } + }) + } + } + } + + impl EnterRuntime { + pub(crate) fn is_entered(self) -> bool { + matches!(self, EnterRuntime::Entered { .. }) + } + } +} + +// Forces the current "entered" state to be cleared while the closure +// is executed. +// +// # Warning +// +// This is hidden for a reason. Do not use without fully understanding +// executors. Misusing can easily cause your program to deadlock. +cfg_rt_multi_thread! { + /// Returns true if in a runtime context. + pub(crate) fn current_enter_context() -> EnterRuntime { + ENTERED.with(|c| c.get()) + } + + pub(crate) fn exit_runtime R, R>(f: F) -> R { + // Reset in case the closure panics + struct Reset(EnterRuntime); + + impl Drop for Reset { + fn drop(&mut self) { + ENTERED.with(|c| { + assert!(!c.get().is_entered(), "closure claimed permanent executor"); + c.set(self.0); + }); + } + } + + let was = ENTERED.with(|c| { + let e = c.get(); + assert!(e.is_entered(), "asked to exit when not entered"); + c.set(EnterRuntime::NotEntered); + e + }); + + let _reset = Reset(was); + // dropping _reset after f() will reset ENTERED + f() + } } diff --git a/tokio/src/runtime/enter.rs b/tokio/src/runtime/enter.rs deleted file mode 100644 index 242de38a062..00000000000 --- a/tokio/src/runtime/enter.rs +++ /dev/null @@ -1,233 +0,0 @@ -use crate::runtime::scheduler; - -use std::cell::{Cell, RefCell}; -use std::fmt; -use std::marker::PhantomData; - -#[derive(Debug, Clone, Copy)] -pub(crate) enum EnterContext { - /// Currently in a runtime context. - #[cfg_attr(not(feature = "rt"), allow(dead_code))] - Entered { allow_block_in_place: bool }, - - /// Not in a runtime context **or** a blocking region. - NotEntered, -} - -impl EnterContext { - pub(crate) fn is_entered(self) -> bool { - matches!(self, EnterContext::Entered { .. }) - } -} - -tokio_thread_local!(static ENTERED: Cell = const { Cell::new(EnterContext::NotEntered) }); - -/// Guard tracking that a caller has entered a runtime context. -pub(crate) struct EnterRuntimeGuard { - pub(crate) blocking: BlockingRegionGuard, -} - -/// Guard tracking that a caller has entered a blocking region. -pub(crate) struct BlockingRegionGuard { - _p: PhantomData>, -} - -cfg_rt! { - use crate::runtime::context; - - use std::time::Duration; - - /// Marks the current thread as being within the dynamic extent of an - /// executor. - #[track_caller] - pub(crate) fn enter_runtime(handle: &scheduler::Handle, allow_block_in_place: bool) -> EnterRuntimeGuard { - if let Some(enter) = try_enter_runtime(allow_block_in_place) { - // Set the current runtime handle. This should not fail. A later - // cleanup will remove the unwrap(). - context::try_set_current(handle).unwrap(); - return enter; - } - - panic!( - "Cannot start a runtime from within a runtime. This happens \ - because a function (like `block_on`) attempted to block the \ - current thread while the thread is being used to drive \ - asynchronous tasks." - ); - } - - /// Tries to enter a runtime context, returns `None` if already in a runtime - /// context. - fn try_enter_runtime(allow_block_in_place: bool) -> Option { - ENTERED.with(|c| { - if c.get().is_entered() { - None - } else { - c.set(EnterContext::Entered { allow_block_in_place }); - Some(EnterRuntimeGuard { - blocking: BlockingRegionGuard::new(), - }) - } - }) - } - - pub(crate) fn try_enter_blocking_region() -> Option { - ENTERED.with(|c| { - if c.get().is_entered() { - None - } else { - Some(BlockingRegionGuard::new()) - } - }) - } -} - -// Forces the current "entered" state to be cleared while the closure -// is executed. -// -// # Warning -// -// This is hidden for a reason. Do not use without fully understanding -// executors. Misusing can easily cause your program to deadlock. -cfg_rt_multi_thread! { - pub(crate) fn exit_runtime R, R>(f: F) -> R { - // Reset in case the closure panics - struct Reset(EnterContext); - impl Drop for Reset { - fn drop(&mut self) { - ENTERED.with(|c| { - assert!(!c.get().is_entered(), "closure claimed permanent executor"); - c.set(self.0); - }); - } - } - - let was = ENTERED.with(|c| { - let e = c.get(); - assert!(e.is_entered(), "asked to exit when not entered"); - c.set(EnterContext::NotEntered); - e - }); - - let _reset = Reset(was); - // dropping _reset after f() will reset ENTERED - f() - } -} - -cfg_rt! { - /// Disallows blocking in the current runtime context until the guard is dropped. - pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard { - let reset = ENTERED.with(|c| { - if let EnterContext::Entered { - allow_block_in_place: true, - } = c.get() - { - c.set(EnterContext::Entered { - allow_block_in_place: false, - }); - true - } else { - false - } - }); - DisallowBlockInPlaceGuard(reset) - } - - pub(crate) struct DisallowBlockInPlaceGuard(bool); - impl Drop for DisallowBlockInPlaceGuard { - fn drop(&mut self) { - if self.0 { - // XXX: Do we want some kind of assertion here, or is "best effort" okay? - ENTERED.with(|c| { - if let EnterContext::Entered { - allow_block_in_place: false, - } = c.get() - { - c.set(EnterContext::Entered { - allow_block_in_place: true, - }); - } - }) - } - } - } -} - -cfg_rt_multi_thread! { - /// Returns true if in a runtime context. - pub(crate) fn context() -> EnterContext { - ENTERED.with(|c| c.get()) - } -} - -cfg_rt! { - use crate::loom::thread::AccessError; - - impl BlockingRegionGuard { - fn new() -> BlockingRegionGuard { - BlockingRegionGuard { _p: PhantomData } - } - /// Blocks the thread on the specified future, returning the value with - /// which that future completes. - pub(crate) fn block_on(&mut self, f: F) -> Result - where - F: std::future::Future, - { - use crate::runtime::park::CachedParkThread; - - let mut park = CachedParkThread::new(); - park.block_on(f) - } - - /// Blocks the thread on the specified future for **at most** `timeout` - /// - /// If the future completes before `timeout`, the result is returned. If - /// `timeout` elapses, then `Err` is returned. - pub(crate) fn block_on_timeout(&mut self, f: F, timeout: Duration) -> Result - where - F: std::future::Future, - { - use crate::runtime::park::CachedParkThread; - use std::task::Context; - use std::task::Poll::Ready; - use std::time::Instant; - - let mut park = CachedParkThread::new(); - let waker = park.waker().map_err(|_| ())?; - let mut cx = Context::from_waker(&waker); - - pin!(f); - let when = Instant::now() + timeout; - - loop { - if let Ready(v) = crate::runtime::coop::budget(|| f.as_mut().poll(&mut cx)) { - return Ok(v); - } - - let now = Instant::now(); - - if now >= when { - return Err(()); - } - - park.park_timeout(when - now); - } - } - } -} - -impl fmt::Debug for EnterRuntimeGuard { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Enter").finish() - } -} - -impl Drop for EnterRuntimeGuard { - fn drop(&mut self) { - ENTERED.with(|c| { - assert!(c.get().is_entered()); - c.set(EnterContext::NotEntered); - }); - } -} diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 1b9ea8bf2b6..da47ecb27b2 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -258,7 +258,7 @@ impl Handle { // Enter the runtime context. This sets the current driver handles and // prevents blocking an existing runtime. - let mut enter = crate::runtime::enter::enter_runtime(&self.inner, true); + let mut enter = context::enter_runtime(&self.inner, true); // Block on the future enter diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index ef86bf37660..6e527801ec6 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -204,8 +204,6 @@ cfg_signal_internal_and_unix! { } cfg_rt! { - pub(crate) mod enter; - pub(crate) mod task; mod config; @@ -230,8 +228,6 @@ cfg_rt! { pub use crate::util::rand::RngSeed; } - use self::enter::enter_runtime; - mod handle; pub use handle::{EnterGuard, Handle, TryCurrentError}; diff --git a/tokio/src/runtime/scheduler/current_thread.rs b/tokio/src/runtime/scheduler/current_thread.rs index 694bc8dae95..d874448c55e 100644 --- a/tokio/src/runtime/scheduler/current_thread.rs +++ b/tokio/src/runtime/scheduler/current_thread.rs @@ -143,7 +143,7 @@ impl CurrentThread { pub(crate) fn block_on(&self, handle: &scheduler::Handle, future: F) -> F::Output { pin!(future); - let mut enter = crate::runtime::enter_runtime(handle, false); + let mut enter = crate::runtime::context::enter_runtime(handle, false); let handle = handle.as_current_thread(); // Attempt to steal the scheduler core and block_on the future if we can diff --git a/tokio/src/runtime/scheduler/multi_thread/mod.rs b/tokio/src/runtime/scheduler/multi_thread/mod.rs index 11a4667ea3f..47cd1f3d7ae 100644 --- a/tokio/src/runtime/scheduler/multi_thread/mod.rs +++ b/tokio/src/runtime/scheduler/multi_thread/mod.rs @@ -62,7 +62,7 @@ impl MultiThread { where F: Future, { - let mut enter = crate::runtime::enter_runtime(handle, true); + let mut enter = crate::runtime::context::enter_runtime(handle, true); enter .blocking .block_on(future) diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs index e1bde3b4251..ee0765230c2 100644 --- a/tokio/src/runtime/scheduler/multi_thread/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs @@ -58,7 +58,7 @@ use crate::loom::sync::{Arc, Mutex}; use crate::runtime; -use crate::runtime::enter::EnterContext; +use crate::runtime::context; use crate::runtime::scheduler::multi_thread::{queue, Handle, Idle, Parker, Unparker}; use crate::runtime::task::{Inject, OwnedTasks}; use crate::runtime::{ @@ -276,14 +276,17 @@ where let mut had_entered = false; let setup_result = CURRENT.with(|maybe_cx| { - match (crate::runtime::enter::context(), maybe_cx.is_some()) { - (EnterContext::Entered { .. }, true) => { + match ( + crate::runtime::context::current_enter_context(), + maybe_cx.is_some(), + ) { + (context::EnterRuntime::Entered { .. }, true) => { // We are on a thread pool runtime thread, so we just need to // set up blocking. had_entered = true; } ( - EnterContext::Entered { + context::EnterRuntime::Entered { allow_block_in_place, }, false, @@ -302,12 +305,12 @@ where ); } } - (EnterContext::NotEntered, true) => { + (context::EnterRuntime::NotEntered, true) => { // This is a nested call to block_in_place (we already exited). // All the necessary setup has already been done. return Ok(()); } - (EnterContext::NotEntered, false) => { + (context::EnterRuntime::NotEntered, false) => { // We are outside of the tokio runtime, so blocking is fine. // We can also skip all of the thread pool blocking setup steps. return Ok(()); @@ -350,7 +353,7 @@ where // constrained by task budgets. let _reset = Reset(coop::stop()); - crate::runtime::enter::exit_runtime(f) + crate::runtime::context::exit_runtime(f) } else { f() } @@ -373,7 +376,7 @@ fn run(worker: Arc) { }; let handle = scheduler::Handle::MultiThread(worker.handle.clone()); - let _enter = crate::runtime::enter_runtime(&handle, true); + let _enter = crate::runtime::context::enter_runtime(&handle, true); // Set the worker context. let cx = Context { diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 25fc7bab672..38ed22b4f6b 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -890,7 +890,7 @@ impl Future for RunUntil<'_, T> { .waker .register_by_ref(cx.waker()); - let _no_blocking = crate::runtime::enter::disallow_block_in_place(); + let _no_blocking = crate::runtime::context::disallow_block_in_place(); let f = me.future; if let Poll::Ready(output) = f.poll(cx) {