From 67fa5ace557cd1e8f7dfe9f850854140c1df5c59 Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Thu, 3 Nov 2022 11:11:28 -0700 Subject: [PATCH] rt: unify entering a runtime with Handle::enter This is a first step towards unifying the concepts of "entering a runtime" and setting `Handle::current`. Previously, these two operations were performed separately at each call site (runtime block_on, ...). This is error prone and also requires multiple accesses to the thread-local variable. Additionally, "entering the runtime" conflated the concept of entering a blocking region. For example, calling `mpsc::Receiver::recv_blocking` performed the "enter the runtime" step. This was done to prevent blocking a runtime as the operation will panic when called from an existing runtime. To untangle these concepts, the patch splits out each logical operation into its own call. In total, there are three "enter" operations: * `set_current_handle` * `enter_runtime` * `enter_blocking_region` There are some behavior changes with each of these functions, but they should not translate to public behavior changes. The most significant is `enter_blocking_region` does not change the value of the thread-local variable, which means the function can be re-entered. Since `enter_blocking_region` is an internal-only function and we do not actually re-enter, this has no public facing impact. Because `enter_runtime` takes a `&Handle` in order to combine the `set_current_handle` operation with entering a runtime, the patch exposes an annoyance with the current `scheduler::Handle` struct layout. A new instance of `scheduler::Handle` must be constructed at each call to `enter_runtime`. We can explore cleaning this up later. This patch also does not combine the "entered runtime" thread-local variable with the "context" thread-local variable. To keep the patch smaller, this has been punted to a follow up change. 
--- tokio/src/future/block_on.rs | 7 ++- tokio/src/runtime/blocking/shutdown.rs | 4 +- tokio/src/runtime/enter.rs | 54 ++++++++++++++----- tokio/src/runtime/handle.rs | 11 ++-- tokio/src/runtime/mod.rs | 2 +- tokio/src/runtime/scheduler/current_thread.rs | 9 ++-- .../src/runtime/scheduler/multi_thread/mod.rs | 10 ++-- .../runtime/scheduler/multi_thread/worker.rs | 9 ++-- 8 files changed, 73 insertions(+), 33 deletions(-) diff --git a/tokio/src/future/block_on.rs b/tokio/src/future/block_on.rs index a624db53538..aff0075f98d 100644 --- a/tokio/src/future/block_on.rs +++ b/tokio/src/future/block_on.rs @@ -3,7 +3,12 @@ use std::future::Future; cfg_rt! { #[track_caller] pub(crate) fn block_on(f: F) -> F::Output { - let mut e = crate::runtime::enter::enter(false); + let mut e = crate::runtime::enter::try_enter_blocking_region().expect( + "Cannot block the current thread from within a runtime. This \ happens because a function attempted to block the current \ thread while the thread is being used to drive asynchronous \ tasks." + ); e.block_on(f).unwrap() } } diff --git a/tokio/src/runtime/blocking/shutdown.rs b/tokio/src/runtime/blocking/shutdown.rs index e6f4674183e..59126cf83c3 100644 --- a/tokio/src/runtime/blocking/shutdown.rs +++ b/tokio/src/runtime/blocking/shutdown.rs @@ -35,13 +35,13 @@ impl Receiver { /// /// If the timeout has elapsed, it returns `false`, otherwise it returns `true`. 
pub(crate) fn wait(&mut self, timeout: Option) -> bool { - use crate::runtime::enter::try_enter; + use crate::runtime::enter::try_enter_blocking_region; if timeout == Some(Duration::from_nanos(0)) { return false; } - let mut e = match try_enter(false) { + let mut e = match try_enter_blocking_region() { Some(enter) => enter, _ => { if std::thread::panicking() { diff --git a/tokio/src/runtime/enter.rs b/tokio/src/runtime/enter.rs index 6089b42c831..242de38a062 100644 --- a/tokio/src/runtime/enter.rs +++ b/tokio/src/runtime/enter.rs @@ -1,13 +1,16 @@ +use crate::runtime::scheduler; + use std::cell::{Cell, RefCell}; use std::fmt; use std::marker::PhantomData; #[derive(Debug, Clone, Copy)] pub(crate) enum EnterContext { + /// Currently in a runtime context. #[cfg_attr(not(feature = "rt"), allow(dead_code))] - Entered { - allow_block_in_place: bool, - }, + Entered { allow_block_in_place: bool }, + + /// Not in a runtime context **or** a blocking region. NotEntered, } @@ -19,19 +22,29 @@ impl EnterContext { tokio_thread_local!(static ENTERED: Cell = const { Cell::new(EnterContext::NotEntered) }); -/// Represents an executor context. -pub(crate) struct Enter { +/// Guard tracking that a caller has entered a runtime context. +pub(crate) struct EnterRuntimeGuard { + pub(crate) blocking: BlockingRegionGuard, +} + +/// Guard tracking that a caller has entered a blocking region. +pub(crate) struct BlockingRegionGuard { _p: PhantomData>, } cfg_rt! { + use crate::runtime::context; + use std::time::Duration; /// Marks the current thread as being within the dynamic extent of an /// executor. #[track_caller] - pub(crate) fn enter(allow_block_in_place: bool) -> Enter { - if let Some(enter) = try_enter(allow_block_in_place) { + pub(crate) fn enter_runtime(handle: &scheduler::Handle, allow_block_in_place: bool) -> EnterRuntimeGuard { + if let Some(enter) = try_enter_runtime(allow_block_in_place) { + // Set the current runtime handle. This should not fail. 
A later + // cleanup will remove the unwrap(). + context::try_set_current(handle).unwrap(); return enter; } @@ -45,13 +58,25 @@ cfg_rt! { /// Tries to enter a runtime context, returns `None` if already in a runtime /// context. - pub(crate) fn try_enter(allow_block_in_place: bool) -> Option { + fn try_enter_runtime(allow_block_in_place: bool) -> Option { ENTERED.with(|c| { if c.get().is_entered() { None } else { c.set(EnterContext::Entered { allow_block_in_place }); - Some(Enter { _p: PhantomData }) + Some(EnterRuntimeGuard { + blocking: BlockingRegionGuard::new(), + }) + } + }) + } + + pub(crate) fn try_enter_blocking_region() -> Option { + ENTERED.with(|c| { + if c.get().is_entered() { + None + } else { + Some(BlockingRegionGuard::new()) } }) } @@ -65,7 +90,7 @@ cfg_rt! { // This is hidden for a reason. Do not use without fully understanding // executors. Misusing can easily cause your program to deadlock. cfg_rt_multi_thread! { - pub(crate) fn exit R, R>(f: F) -> R { + pub(crate) fn exit_runtime R, R>(f: F) -> R { // Reset in case the closure panics struct Reset(EnterContext); impl Drop for Reset { @@ -139,7 +164,10 @@ cfg_rt_multi_thread! { cfg_rt! { use crate::loom::thread::AccessError; - impl Enter { + impl BlockingRegionGuard { + fn new() -> BlockingRegionGuard { + BlockingRegionGuard { _p: PhantomData } + } /// Blocks the thread on the specified future, returning the value with /// which that future completes. pub(crate) fn block_on(&mut self, f: F) -> Result @@ -189,13 +217,13 @@ cfg_rt! 
{ } } -impl fmt::Debug for Enter { +impl fmt::Debug for EnterRuntimeGuard { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Enter").finish() } } -impl Drop for Enter { +impl Drop for EnterRuntimeGuard { fn drop(&mut self) { ENTERED.with(|c| { assert!(c.get().is_entered()); diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index a4f34b87d4c..1b9ea8bf2b6 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -256,14 +256,13 @@ impl Handle { let future = crate::util::trace::task(future, "block_on", None, super::task::Id::next().as_u64()); - // Enter the **runtime** context. This configures spawning, the current I/O driver, ... - let _rt_enter = self.enter(); - - // Enter a **blocking** context. This prevents blocking from a runtime. - let mut blocking_enter = crate::runtime::enter(true); + // Enter the runtime context. This sets the current driver handles and + // prevents blocking an existing runtime. + let mut enter = crate::runtime::enter::enter_runtime(&self.inner, true); // Block on the future - blocking_enter + enter + .blocking .block_on(future) .expect("failed to park thread") } diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 92e2c387680..ef86bf37660 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -230,7 +230,7 @@ cfg_rt! 
{ pub use crate::util::rand::RngSeed; } - use self::enter::enter; + use self::enter::enter_runtime; mod handle; pub use handle::{EnterGuard, Handle, TryCurrentError}; diff --git a/tokio/src/runtime/scheduler/current_thread.rs b/tokio/src/runtime/scheduler/current_thread.rs index 778f93a84c5..6bcc1750a44 100644 --- a/tokio/src/runtime/scheduler/current_thread.rs +++ b/tokio/src/runtime/scheduler/current_thread.rs @@ -155,8 +155,13 @@ impl CurrentThread { #[track_caller] pub(crate) fn block_on(&self, future: F) -> F::Output { + use crate::runtime::scheduler; + pin!(future); + let handle = scheduler::Handle::CurrentThread(self.handle.clone()); + let mut enter = crate::runtime::enter_runtime(&handle, false); + // Attempt to steal the scheduler core and block_on the future if we can // there, otherwise, lets select on a notification that the core is // available or the future is complete. @@ -164,12 +169,11 @@ impl CurrentThread { if let Some(core) = self.take_core() { return core.block_on(future); } else { - let mut enter = crate::runtime::enter(false); - let notified = self.notify.notified(); pin!(notified); if let Some(out) = enter + .blocking .block_on(poll_fn(|cx| { if notified.as_mut().poll(cx).is_ready() { return Ready(None); @@ -522,7 +526,6 @@ impl CoreGuard<'_> { #[track_caller] fn block_on(self, future: F) -> F::Output { let ret = self.enter(|mut core, context| { - let _enter = crate::runtime::enter(false); let waker = Handle::waker_ref(&context.handle); let mut cx = std::task::Context::from_waker(&waker); diff --git a/tokio/src/runtime/scheduler/multi_thread/mod.rs b/tokio/src/runtime/scheduler/multi_thread/mod.rs index 8a1b5b6c483..2b96901cce0 100644 --- a/tokio/src/runtime/scheduler/multi_thread/mod.rs +++ b/tokio/src/runtime/scheduler/multi_thread/mod.rs @@ -20,7 +20,7 @@ use crate::loom::sync::Arc; use crate::runtime::{ blocking, driver::{self, Driver}, - Config, + scheduler, Config, }; use crate::util::RngSeedGenerator; @@ -73,8 +73,12 @@ impl 
MultiThread { where F: Future, { - let mut enter = crate::runtime::enter(true); - enter.block_on(future).expect("failed to park thread") + let handle = scheduler::Handle::MultiThread(self.handle.clone()); + let mut enter = crate::runtime::enter_runtime(&handle, true); + enter + .blocking + .block_on(future) + .expect("failed to park thread") } } diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs index b21a11923d9..e1bde3b4251 100644 --- a/tokio/src/runtime/scheduler/multi_thread/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs @@ -62,7 +62,7 @@ use crate::runtime::enter::EnterContext; use crate::runtime::scheduler::multi_thread::{queue, Handle, Idle, Parker, Unparker}; use crate::runtime::task::{Inject, OwnedTasks}; use crate::runtime::{ - blocking, coop, driver, task, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics, + blocking, coop, driver, scheduler, task, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics, }; use crate::util::atomic_cell::AtomicCell; use crate::util::rand::{FastRand, RngSeedGenerator}; @@ -350,7 +350,7 @@ where // constrained by task budgets. let _reset = Reset(coop::stop()); - crate::runtime::enter::exit(f) + crate::runtime::enter::exit_runtime(f) } else { f() } @@ -372,14 +372,15 @@ fn run(worker: Arc) { None => return, }; + let handle = scheduler::Handle::MultiThread(worker.handle.clone()); + let _enter = crate::runtime::enter_runtime(&handle, true); + // Set the worker context. let cx = Context { worker, core: RefCell::new(None), }; - let _enter = crate::runtime::enter(true); - CURRENT.set(&cx, || { // This should always be an error. It only returns a `Result` to support // using `?` to short circuit.