diff --git a/tokio/src/future/block_on.rs b/tokio/src/future/block_on.rs index a624db53538..aff0075f98d 100644 --- a/tokio/src/future/block_on.rs +++ b/tokio/src/future/block_on.rs @@ -3,7 +3,12 @@ use std::future::Future; cfg_rt! { #[track_caller] pub(crate) fn block_on(f: F) -> F::Output { - let mut e = crate::runtime::enter::enter(false); + let mut e = crate::runtime::enter::try_enter_blocking_region().expect( + "Cannot block the current thread from within a runtime. This \ + happens because a functionattempted to block the current \ + thread while the thread is being used to drive asynchronous \ + tasks." + ); e.block_on(f).unwrap() } } diff --git a/tokio/src/runtime/blocking/shutdown.rs b/tokio/src/runtime/blocking/shutdown.rs index e6f4674183e..59126cf83c3 100644 --- a/tokio/src/runtime/blocking/shutdown.rs +++ b/tokio/src/runtime/blocking/shutdown.rs @@ -35,13 +35,13 @@ impl Receiver { /// /// If the timeout has elapsed, it returns `false`, otherwise it returns `true`. pub(crate) fn wait(&mut self, timeout: Option) -> bool { - use crate::runtime::enter::try_enter; + use crate::runtime::enter::try_enter_blocking_region; if timeout == Some(Duration::from_nanos(0)) { return false; } - let mut e = match try_enter(false) { + let mut e = match try_enter_blocking_region() { Some(enter) => enter, _ => { if std::thread::panicking() { diff --git a/tokio/src/runtime/enter.rs b/tokio/src/runtime/enter.rs index 6089b42c831..242de38a062 100644 --- a/tokio/src/runtime/enter.rs +++ b/tokio/src/runtime/enter.rs @@ -1,13 +1,16 @@ +use crate::runtime::scheduler; + use std::cell::{Cell, RefCell}; use std::fmt; use std::marker::PhantomData; #[derive(Debug, Clone, Copy)] pub(crate) enum EnterContext { + /// Currently in a runtime context. #[cfg_attr(not(feature = "rt"), allow(dead_code))] - Entered { - allow_block_in_place: bool, - }, + Entered { allow_block_in_place: bool }, + + /// Not in a runtime context **or** a blocking region. NotEntered, } @@ -19,19 +22,29 @@ impl EnterContext { tokio_thread_local!(static ENTERED: Cell = const { Cell::new(EnterContext::NotEntered) }); -/// Represents an executor context. -pub(crate) struct Enter { +/// Guard tracking that a caller has entered a runtime context. +pub(crate) struct EnterRuntimeGuard { + pub(crate) blocking: BlockingRegionGuard, +} + +/// Guard tracking that a caller has entered a blocking region. +pub(crate) struct BlockingRegionGuard { _p: PhantomData>, } cfg_rt! { + use crate::runtime::context; + use std::time::Duration; /// Marks the current thread as being within the dynamic extent of an /// executor. #[track_caller] - pub(crate) fn enter(allow_block_in_place: bool) -> Enter { - if let Some(enter) = try_enter(allow_block_in_place) { + pub(crate) fn enter_runtime(handle: &scheduler::Handle, allow_block_in_place: bool) -> EnterRuntimeGuard { + if let Some(enter) = try_enter_runtime(allow_block_in_place) { + // Set the current runtime handle. This should not fail. A later + // cleanup will remove the unwrap(). + context::try_set_current(handle).unwrap(); return enter; } @@ -45,13 +58,25 @@ cfg_rt! { /// Tries to enter a runtime context, returns `None` if already in a runtime /// context. - pub(crate) fn try_enter(allow_block_in_place: bool) -> Option { + fn try_enter_runtime(allow_block_in_place: bool) -> Option { ENTERED.with(|c| { if c.get().is_entered() { None } else { c.set(EnterContext::Entered { allow_block_in_place }); - Some(Enter { _p: PhantomData }) + Some(EnterRuntimeGuard { + blocking: BlockingRegionGuard::new(), + }) + } + }) + } + + pub(crate) fn try_enter_blocking_region() -> Option { + ENTERED.with(|c| { + if c.get().is_entered() { + None + } else { + Some(BlockingRegionGuard::new()) } }) } @@ -65,7 +90,7 @@ cfg_rt! { // This is hidden for a reason. Do not use without fully understanding // executors. Misusing can easily cause your program to deadlock. cfg_rt_multi_thread! { - pub(crate) fn exit R, R>(f: F) -> R { + pub(crate) fn exit_runtime R, R>(f: F) -> R { // Reset in case the closure panics struct Reset(EnterContext); impl Drop for Reset { @@ -139,7 +164,10 @@ cfg_rt_multi_thread! { cfg_rt! { use crate::loom::thread::AccessError; - impl Enter { + impl BlockingRegionGuard { + fn new() -> BlockingRegionGuard { + BlockingRegionGuard { _p: PhantomData } + } /// Blocks the thread on the specified future, returning the value with /// which that future completes. pub(crate) fn block_on(&mut self, f: F) -> Result @@ -189,13 +217,13 @@ cfg_rt! { } } -impl fmt::Debug for Enter { +impl fmt::Debug for EnterRuntimeGuard { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Enter").finish() } } -impl Drop for Enter { +impl Drop for EnterRuntimeGuard { fn drop(&mut self) { ENTERED.with(|c| { assert!(c.get().is_entered()); diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index a4f34b87d4c..1b9ea8bf2b6 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -256,14 +256,13 @@ impl Handle { let future = crate::util::trace::task(future, "block_on", None, super::task::Id::next().as_u64()); - // Enter the **runtime** context. This configures spawning, the current I/O driver, ... - let _rt_enter = self.enter(); - - // Enter a **blocking** context. This prevents blocking from a runtime. - let mut blocking_enter = crate::runtime::enter(true); + // Enter the runtime context. This sets the current driver handles and + // prevents blocking an existing runtime. + let mut enter = crate::runtime::enter::enter_runtime(&self.inner, true); // Block on the future - blocking_enter + enter + .blocking .block_on(future) .expect("failed to park thread") } diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 92e2c387680..ef86bf37660 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -230,7 +230,7 @@ cfg_rt! { pub use crate::util::rand::RngSeed; } - use self::enter::enter; + use self::enter::enter_runtime; mod handle; pub use handle::{EnterGuard, Handle, TryCurrentError}; diff --git a/tokio/src/runtime/scheduler/current_thread.rs b/tokio/src/runtime/scheduler/current_thread.rs index 778f93a84c5..6bcc1750a44 100644 --- a/tokio/src/runtime/scheduler/current_thread.rs +++ b/tokio/src/runtime/scheduler/current_thread.rs @@ -155,8 +155,13 @@ impl CurrentThread { #[track_caller] pub(crate) fn block_on(&self, future: F) -> F::Output { + use crate::runtime::scheduler; + pin!(future); + let handle = scheduler::Handle::CurrentThread(self.handle.clone()); + let mut enter = crate::runtime::enter_runtime(&handle, false); + // Attempt to steal the scheduler core and block_on the future if we can // there, otherwise, lets select on a notification that the core is // available or the future is complete. @@ -164,12 +169,11 @@ impl CurrentThread { if let Some(core) = self.take_core() { return core.block_on(future); } else { - let mut enter = crate::runtime::enter(false); - let notified = self.notify.notified(); pin!(notified); if let Some(out) = enter + .blocking .block_on(poll_fn(|cx| { if notified.as_mut().poll(cx).is_ready() { return Ready(None); @@ -522,7 +526,6 @@ impl CoreGuard<'_> { #[track_caller] fn block_on(self, future: F) -> F::Output { let ret = self.enter(|mut core, context| { - let _enter = crate::runtime::enter(false); let waker = Handle::waker_ref(&context.handle); let mut cx = std::task::Context::from_waker(&waker); diff --git a/tokio/src/runtime/scheduler/multi_thread/mod.rs b/tokio/src/runtime/scheduler/multi_thread/mod.rs index 8a1b5b6c483..2b96901cce0 100644 --- a/tokio/src/runtime/scheduler/multi_thread/mod.rs +++ b/tokio/src/runtime/scheduler/multi_thread/mod.rs @@ -20,7 +20,7 @@ use crate::loom::sync::Arc; use crate::runtime::{ blocking, driver::{self, Driver}, - Config, + scheduler, Config, }; use crate::util::RngSeedGenerator; @@ -73,8 +73,12 @@ impl MultiThread { where F: Future, { - let mut enter = crate::runtime::enter(true); - enter.block_on(future).expect("failed to park thread") + let handle = scheduler::Handle::MultiThread(self.handle.clone()); + let mut enter = crate::runtime::enter_runtime(&handle, true); + enter + .blocking + .block_on(future) + .expect("failed to park thread") } } diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs index b21a11923d9..e1bde3b4251 100644 --- a/tokio/src/runtime/scheduler/multi_thread/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs @@ -62,7 +62,7 @@ use crate::runtime::enter::EnterContext; use crate::runtime::scheduler::multi_thread::{queue, Handle, Idle, Parker, Unparker}; use crate::runtime::task::{Inject, OwnedTasks}; use crate::runtime::{ - blocking, coop, driver, task, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics, + blocking, coop, driver, scheduler, task, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics, }; use crate::util::atomic_cell::AtomicCell; use crate::util::rand::{FastRand, RngSeedGenerator}; @@ -350,7 +350,7 @@ where // constrained by task budgets. let _reset = Reset(coop::stop()); - crate::runtime::enter::exit(f) + crate::runtime::enter::exit_runtime(f) } else { f() } @@ -372,14 +372,15 @@ fn run(worker: Arc) { None => return, }; + let handle = scheduler::Handle::MultiThread(worker.handle.clone()); + let _enter = crate::runtime::enter_runtime(&handle, true); + // Set the worker context. let cx = Context { worker, core: RefCell::new(None), }; - let _enter = crate::runtime::enter(true); - CURRENT.set(&cx, || { // This should always be an error. It only returns a `Result` to support // using `?` to short circuit.