diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs
index 50c66657cc4..4a85d0ddfb2 100644
--- a/tokio/src/runtime/builder.rs
+++ b/tokio/src/runtime/builder.rs
@@ -888,7 +888,7 @@ impl Builder {
         // there are no futures ready to do something, it'll let the timer or
         // the reactor to generate some new stimuli for the futures to continue
         // in their life.
-        let scheduler = CurrentThread::new(
+        let (scheduler, handle) = CurrentThread::new(
             driver,
             driver_handle,
             blocking_spawner,
@@ -906,7 +906,7 @@ impl Builder {
         );
 
         let handle = Handle {
-            inner: scheduler::Handle::CurrentThread(scheduler.handle().clone()),
+            inner: scheduler::Handle::CurrentThread(handle),
         };
 
         Ok(Runtime::from_parts(
@@ -1009,7 +1009,7 @@ cfg_rt_multi_thread! {
             let seed_generator_1 = self.seed_generator.next_generator();
             let seed_generator_2 = self.seed_generator.next_generator();
 
-            let (scheduler, launch) = MultiThread::new(
+            let (scheduler, handle, launch) = MultiThread::new(
                 core_threads,
                 driver,
                 driver_handle,
@@ -1027,8 +1027,7 @@
                 },
             );
 
-            let handle = scheduler::Handle::MultiThread(scheduler.handle().clone());
-            let handle = Handle { inner: handle };
+            let handle = Handle { inner: scheduler::Handle::MultiThread(handle) };
 
             // Spawn the thread pool workers
             let _enter = handle.enter();
diff --git a/tokio/src/runtime/runtime.rs b/tokio/src/runtime/runtime.rs
index fad0afa2211..3d4fd67884d 100644
--- a/tokio/src/runtime/runtime.rs
+++ b/tokio/src/runtime/runtime.rs
@@ -276,9 +276,9 @@ impl Runtime {
         let _enter = self.enter();
 
         match &self.scheduler {
-            Scheduler::CurrentThread(exec) => exec.block_on(future),
+            Scheduler::CurrentThread(exec) => exec.block_on(&self.handle.inner, future),
             #[cfg(all(feature = "rt-multi-thread", not(tokio_wasi)))]
-            Scheduler::MultiThread(exec) => exec.block_on(future),
+            Scheduler::MultiThread(exec) => exec.block_on(&self.handle.inner, future),
         }
     }
@@ -397,20 +397,14 @@ impl Drop for Runtime {
             Scheduler::CurrentThread(current_thread) => {
                 // This ensures that tasks spawned on the current-thread
                 // runtime are dropped inside the runtime's context.
-                match context::try_set_current(&self.handle.inner) {
-                    Some(guard) => current_thread.set_context_guard(guard),
-                    None => {
-                        // The context thread-local has already been destroyed.
-                        //
-                        // We don't set the guard in this case. Calls to tokio::spawn in task
-                        // destructors would fail regardless if this happens.
-                    }
-                }
+                let _guard = context::try_set_current(&self.handle.inner);
+                current_thread.shutdown(&self.handle.inner);
             }
             #[cfg(all(feature = "rt-multi-thread", not(tokio_wasi)))]
-            Scheduler::MultiThread(_) => {
+            Scheduler::MultiThread(multi_thread) => {
                 // The threaded scheduler drops its tasks on its worker threads, which is
                 // already in the runtime's context.
+                multi_thread.shutdown(&self.handle.inner);
             }
         }
     }
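
[Editor's note, not part of the patch] The builder now receives the scheduler's handle directly from the constructor, and `Runtime::block_on`/`Drop` thread `&self.handle.inner` down into the scheduler instead of the scheduler reaching for a stored copy. The observable public surface should be unchanged; a minimal sketch against the public API, assuming a `tokio` dependency with the `rt` feature:

    use tokio::runtime::Builder;

    fn main() {
        // Builder::build() internally destructures `CurrentThread::new(..)`
        // into `(scheduler, handle)` now, but callers see no difference.
        let rt = Builder::new_current_thread().build().unwrap();

        // Runtime::block_on forwards `&self.handle.inner` to the scheduler.
        let answer = rt.block_on(async { 40 + 2 });
        assert_eq!(answer, 42);
    } // Runtime's Drop runs here and drives the scheduler's explicit shutdown.
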
diff --git a/tokio/src/runtime/scheduler/current_thread.rs b/tokio/src/runtime/scheduler/current_thread.rs
index 9297687023f..694bc8dae95 100644
--- a/tokio/src/runtime/scheduler/current_thread.rs
+++ b/tokio/src/runtime/scheduler/current_thread.rs
@@ -1,10 +1,9 @@
 use crate::future::poll_fn;
 use crate::loom::sync::atomic::AtomicBool;
 use crate::loom::sync::{Arc, Mutex};
-use crate::runtime::context::SetCurrentGuard;
 use crate::runtime::driver::{self, Driver};
 use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task};
-use crate::runtime::{blocking, Config};
+use crate::runtime::{blocking, scheduler, Config};
 use crate::runtime::{MetricsBatch, SchedulerMetrics, WorkerMetrics};
 use crate::sync::notify::Notify;
 use crate::util::atomic_cell::AtomicCell;
@@ -26,15 +25,6 @@ pub(crate) struct CurrentThread {
     /// Notifier for waking up other threads to steal the
     /// driver.
     notify: Notify,
-
-    /// Shared handle to the scheduler
-    handle: Arc<Handle>,
-
-    /// This is usually None, but right before dropping the CurrentThread
-    /// scheduler, it is changed to `Some` with the context being the runtime's
-    /// own context. This ensures that any tasks dropped in the `CurrentThread`'s
-    /// destructor run in that runtime's context.
-    context_guard: Option<SetCurrentGuard>,
 }
 
 /// Handle to the current thread scheduler
@@ -118,7 +108,7 @@
         blocking_spawner: blocking::Spawner,
         seed_generator: RngSeedGenerator,
         config: Config,
-    ) -> CurrentThread {
+    ) -> (CurrentThread, Arc<Handle>) {
         let handle = Arc::new(Handle {
             shared: Shared {
                 queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))),
@@ -141,32 +131,26 @@
             unhandled_panic: false,
         })));
 
-        CurrentThread {
+        let scheduler = CurrentThread {
             core,
             notify: Notify::new(),
-            handle,
-            context_guard: None,
-        }
-    }
+        };
 
-    pub(crate) fn handle(&self) -> &Arc<Handle> {
-        &self.handle
+        (scheduler, handle)
     }
 
     #[track_caller]
-    pub(crate) fn block_on<F: Future>(&self, future: F) -> F::Output {
-        use crate::runtime::scheduler;
-
+    pub(crate) fn block_on<F: Future>(&self, handle: &scheduler::Handle, future: F) -> F::Output {
         pin!(future);
 
-        let handle = scheduler::Handle::CurrentThread(self.handle.clone());
-        let mut enter = crate::runtime::enter_runtime(&handle, false);
+        let mut enter = crate::runtime::enter_runtime(handle, false);
+        let handle = handle.as_current_thread();
 
         // Attempt to steal the scheduler core and block_on the future if we can
         // there, otherwise, lets select on a notification that the core is
         // available or the future is complete.
         loop {
-            if let Some(core) = self.take_core() {
+            if let Some(core) = self.take_core(handle) {
                 return core.block_on(future);
             } else {
                 let notified = self.notify.notified();
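
[Editor's note, not part of the patch] The constructor change follows a simple ownership pattern: instead of stashing the `Arc<Handle>` inside the scheduler and exposing it via a `handle()` getter, `new` hands both halves to the caller. A stand-alone sketch of that pattern (names hypothetical, not Tokio internals):

    use std::sync::Arc;

    struct SchedulerHandle; // stands in for current_thread::Handle
    struct Scheduler;       // stands in for CurrentThread

    impl Scheduler {
        // new() returns the scheduler and its shared handle as a pair, so the
        // scheduler no longer needs to store (and clone out) its own handle.
        fn new() -> (Scheduler, Arc<SchedulerHandle>) {
            (Scheduler, Arc::new(SchedulerHandle))
        }
    }

    fn main() {
        let (_scheduler, _handle) = Scheduler::new();
        // The caller (Builder, in the real patch) decides where the handle
        // lives and later passes it back into block_on and shutdown.
    }
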
@@ -193,48 +177,44 @@
         }
     }
 
-    fn take_core(&self) -> Option<CoreGuard<'_>> {
+    fn take_core(&self, handle: &Arc<Handle>) -> Option<CoreGuard<'_>> {
         let core = self.core.take()?;
 
         Some(CoreGuard {
             context: Context {
-                handle: self.handle.clone(),
+                handle: handle.clone(),
                 core: RefCell::new(Some(core)),
             },
             scheduler: self,
         })
     }
 
-    pub(crate) fn set_context_guard(&mut self, guard: SetCurrentGuard) {
-        self.context_guard = Some(guard);
-    }
-}
+    pub(crate) fn shutdown(&mut self, handle: &scheduler::Handle) {
+        let handle = handle.as_current_thread();
 
-impl Drop for CurrentThread {
-    fn drop(&mut self) {
         // Avoid a double panic if we are currently panicking and
         // the lock may be poisoned.
-        let core = match self.take_core() {
+        let core = match self.take_core(handle) {
             Some(core) => core,
             None if std::thread::panicking() => return,
             None => panic!("Oh no! We never placed the Core back, this is a bug!"),
         };
 
-        core.enter(|mut core, context| {
+        core.enter(|mut core, _context| {
             // Drain the OwnedTasks collection. This call also closes the
             // collection, ensuring that no tasks are ever pushed after this
             // call returns.
-            context.handle.shared.owned.close_and_shutdown_all();
+            handle.shared.owned.close_and_shutdown_all();
 
             // Drain local queue
             // We already shut down every task, so we just need to drop the task.
-            while let Some(task) = core.pop_task(&self.handle) {
+            while let Some(task) = core.pop_task(handle) {
                 drop(task);
             }
 
             // Drain remote queue and set it to None
-            let remote_queue = self.handle.shared.queue.lock().take();
+            let remote_queue = handle.shared.queue.lock().take();
 
             // Using `Option::take` to replace the shared queue with `None`.
             // We already shut down every task, so we just need to drop the task.
@@ -244,14 +224,14 @@ impl Drop for CurrentThread {
                 }
             }
 
-            assert!(context.handle.shared.owned.is_empty());
+            assert!(handle.shared.owned.is_empty());
 
             // Submit metrics
-            core.metrics.submit(&self.handle.shared.worker_metrics);
+            core.metrics.submit(&handle.shared.worker_metrics);
 
             // Shutdown the resource drivers
             if let Some(driver) = core.driver.as_mut() {
-                driver.shutdown(&self.handle.driver);
+                driver.shutdown(&handle.driver);
             }
 
             (core, ())
@@ -299,10 +279,10 @@ impl Context {
 
     /// Blocks the current thread until an event is received by the driver,
     /// including I/O events, timer events, ...
-    fn park(&self, mut core: Box<Core>) -> Box<Core> {
+    fn park(&self, mut core: Box<Core>, handle: &Handle) -> Box<Core> {
         let mut driver = core.driver.take().expect("driver missing");
 
-        if let Some(f) = &self.handle.shared.config.before_park {
+        if let Some(f) = &handle.shared.config.before_park {
             // Incorrect lint, the closures are actually different types so `f`
             // cannot be passed as an argument to `enter`.
             #[allow(clippy::redundant_closure)]
@@ -315,17 +295,17 @@
         if core.tasks.is_empty() {
             // Park until the thread is signaled
             core.metrics.about_to_park();
-            core.metrics.submit(&self.handle.shared.worker_metrics);
+            core.metrics.submit(&handle.shared.worker_metrics);
 
             let (c, _) = self.enter(core, || {
-                driver.park(&self.handle.driver);
+                driver.park(&handle.driver);
             });
 
             core = c;
             core.metrics.returned_from_park();
         }
 
-        if let Some(f) = &self.handle.shared.config.after_unpark {
+        if let Some(f) = &handle.shared.config.after_unpark {
             // Incorrect lint, the closures are actually different types so `f`
             // cannot be passed as an argument to `enter`.
             #[allow(clippy::redundant_closure)]
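
[Editor's note, not part of the patch] Moving the teardown from `impl Drop for CurrentThread` into an explicit `shutdown(&mut self, handle)` lets the owner supply the handle (and hold the thread-local context guard) at shutdown time, instead of the scheduler smuggling both in through fields. A toy illustration of that Drop-to-explicit-shutdown move (all names hypothetical):

    struct Handle;

    struct Scheduler {
        tasks: Vec<String>,
    }

    impl Scheduler {
        // Teardown takes the handle explicitly instead of reading a stored
        // field, loosely mirroring how CurrentThread::shutdown drains the
        // owned-task collection and the local/remote queues.
        fn shutdown(&mut self, _handle: &Handle) {
            while let Some(task) = self.tasks.pop() {
                drop(task);
            }
        }
    }

    struct Runtime {
        scheduler: Scheduler,
        handle: Handle,
    }

    impl Drop for Runtime {
        fn drop(&mut self) {
            // The owner drives shutdown and supplies the handle, so the
            // scheduler needs neither a stored handle nor a context guard.
            self.scheduler.shutdown(&self.handle);
        }
    }

    fn main() {
        let rt = Runtime {
            scheduler: Scheduler { tasks: vec![String::from("task")] },
            handle: Handle,
        };
        drop(rt);
    }
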
@@ -338,12 +318,12 @@
     }
 
     /// Checks the driver for new events without blocking the thread.
-    fn park_yield(&self, mut core: Box<Core>) -> Box<Core> {
+    fn park_yield(&self, mut core: Box<Core>, handle: &Handle) -> Box<Core> {
         let mut driver = core.driver.take().expect("driver missing");
 
-        core.metrics.submit(&self.handle.shared.worker_metrics);
+        core.metrics.submit(&handle.shared.worker_metrics);
         let (mut core, _) = self.enter(core, || {
-            driver.park_timeout(&self.handle.driver, Duration::from_millis(0));
+            driver.park_timeout(&handle.driver, Duration::from_millis(0));
         });
 
         core.driver = Some(driver);
@@ -577,7 +557,7 @@ impl CoreGuard<'_> {
                     let task = match entry {
                         Some(entry) => entry,
                         None => {
-                            core = context.park(core);
+                            core = context.park(core, handle);
 
                             // Try polling the `block_on` future next
                             continue 'outer;
@@ -595,7 +575,7 @@ impl CoreGuard<'_> {
 
                     // Yield to the driver, this drives the timer and pulls any
                     // pending I/O events.
-                    core = context.park_yield(core);
+                    core = context.park_yield(core, handle);
                 }
             });
diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs
index b991c89f00c..f45d8a80ba4 100644
--- a/tokio/src/runtime/scheduler/mod.rs
+++ b/tokio/src/runtime/scheduler/mod.rs
@@ -97,6 +97,14 @@ cfg_rt! {
                 Handle::MultiThread(h) => &h.seed_generator,
             }
         }
+
+        pub(crate) fn as_current_thread(&self) -> &Arc<current_thread::Handle> {
+            match self {
+                Handle::CurrentThread(handle) => handle,
+                #[cfg(all(feature = "rt-multi-thread", not(tokio_wasi)))]
+                _ => panic!("not a CurrentThread handle"),
+            }
+        }
     }
 
     cfg_metrics! {
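
[Editor's note, not part of the patch] `as_current_thread` is a panicking downcast over the scheduler-handle enum; `block_on` and `shutdown` call it only after the caller has already matched on the scheduler kind, so a mismatch is a logic error. A stand-alone sketch of the accessor pattern (simplified, hypothetical types):

    use std::sync::Arc;

    struct CurrentThreadHandle;

    enum Handle {
        CurrentThread(Arc<CurrentThreadHandle>),
        MultiThread,
    }

    impl Handle {
        // Callers that statically know the variant get the concrete handle
        // back; a mismatch panics rather than returning an Option.
        fn as_current_thread(&self) -> &Arc<CurrentThreadHandle> {
            match self {
                Handle::CurrentThread(handle) => handle,
                Handle::MultiThread => panic!("not a CurrentThread handle"),
            }
        }
    }

    fn main() {
        let handle = Handle::CurrentThread(Arc::new(CurrentThreadHandle));
        let _inner = handle.as_current_thread();
    }
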
diff --git a/tokio/src/runtime/scheduler/multi_thread/mod.rs b/tokio/src/runtime/scheduler/multi_thread/mod.rs
index 2b96901cce0..11a4667ea3f 100644
--- a/tokio/src/runtime/scheduler/multi_thread/mod.rs
+++ b/tokio/src/runtime/scheduler/multi_thread/mod.rs
@@ -28,9 +28,7 @@ use std::fmt;
 use std::future::Future;
 
 /// Work-stealing based thread pool for executing futures.
-pub(crate) struct MultiThread {
-    handle: Arc<Handle>,
-}
+pub(crate) struct MultiThread;
 
 // ===== impl MultiThread =====
 
@@ -42,7 +40,7 @@ impl MultiThread {
         blocking_spawner: blocking::Spawner,
         seed_generator: RngSeedGenerator,
         config: Config,
-    ) -> (MultiThread, Launch) {
+    ) -> (MultiThread, Arc<Handle>, Launch) {
         let parker = Parker::new(driver);
         let (handle, launch) = worker::create(
             size,
@@ -52,34 +50,31 @@
             seed_generator,
             config,
         );
-        let multi_thread = MultiThread { handle };
-        (multi_thread, launch)
-    }
-
-    /// Returns reference to `Spawner`.
-    ///
-    /// The `Spawner` handle can be cloned and enables spawning tasks from other
-    /// threads.
-    pub(crate) fn handle(&self) -> &Arc<Handle> {
-        &self.handle
+        (MultiThread, handle, launch)
     }
 
     /// Blocks the current thread waiting for the future to complete.
     ///
     /// The future will execute on the current thread, but all spawned tasks
     /// will be executed on the thread pool.
-    pub(crate) fn block_on<F>(&self, future: F) -> F::Output
+    pub(crate) fn block_on<F>(&self, handle: &scheduler::Handle, future: F) -> F::Output
     where
         F: Future,
     {
-        let handle = scheduler::Handle::MultiThread(self.handle.clone());
-        let mut enter = crate::runtime::enter_runtime(&handle, true);
+        let mut enter = crate::runtime::enter_runtime(handle, true);
         enter
             .blocking
             .block_on(future)
             .expect("failed to park thread")
     }
+
+    pub(crate) fn shutdown(&mut self, handle: &scheduler::Handle) {
+        match handle {
+            scheduler::Handle::MultiThread(handle) => handle.shutdown(),
+            _ => panic!("expected MultiThread scheduler"),
+        }
+    }
 }
 
 impl fmt::Debug for MultiThread {
@@ -87,9 +82,3 @@ impl fmt::Debug for MultiThread {
         fmt.debug_struct("MultiThread").finish()
     }
 }
-
-impl Drop for MultiThread {
-    fn drop(&mut self) {
-        self.handle.shutdown();
-    }
-}
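
[Editor's note, not part of the patch] With its `handle` field and `Drop` impl removed, `MultiThread` becomes a unit struct: the handle travels separately and shutdown is driven by `Runtime::drop` rather than by the scheduler's own destructor. Behavior through the public API should be unchanged; a sketch assuming the `rt-multi-thread` feature:

    fn main() {
        // Runtime::new builds the work-stealing pool.
        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(async {
            // Spawned tasks still run on the pool's worker threads.
            tokio::spawn(async { 1 + 1 }).await.unwrap();
        });

        // Dropping the runtime now reaches MultiThread::shutdown through
        // Runtime's own Drop impl, not a Drop impl on MultiThread itself.
        drop(rt);
    }
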