Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rt: panic if EnterGuard dropped incorrect order #5772

Merged
merged 5 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 2 additions & 3 deletions tokio/src/runtime/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ cfg_rt! {

use crate::runtime::{scheduler, task::Id};

use std::cell::RefCell;
use std::task::Waker;

cfg_taskdump! {
Expand All @@ -41,7 +40,7 @@ struct Context {

/// Handle to the runtime scheduler running on the current thread.
#[cfg(feature = "rt")]
handle: RefCell<Option<scheduler::Handle>>,
current: current::HandleCell,

/// Handle to the scheduler's internal "context"
#[cfg(feature = "rt")]
Expand Down Expand Up @@ -84,7 +83,7 @@ tokio_thread_local! {
/// Tracks the current runtime handle to use when spawning,
/// accessing drivers, etc...
#[cfg(feature = "rt")]
handle: RefCell::new(None),
current: current::HandleCell::new(),

/// Tracks the current scheduler internal context
#[cfg(feature = "rt")]
Expand Down
70 changes: 55 additions & 15 deletions tokio/src/runtime/context/current.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,42 @@ use super::{Context, CONTEXT};
use crate::runtime::{scheduler, TryCurrentError};
use crate::util::markers::SyncNotSend;

use std::cell::{Cell, RefCell};
use std::marker::PhantomData;

#[derive(Debug)]
#[must_use]
pub(crate) struct SetCurrentGuard {
old_handle: Option<scheduler::Handle>,
// The previous handle
prev: Option<scheduler::Handle>,

// The depth for this guard
depth: u16,

// Don't let the type move across threads.
_p: PhantomData<SyncNotSend>,
}

pub(super) struct HandleCell {
/// Current handle
handle: RefCell<Option<scheduler::Handle>>,

/// Tracks the number of nested calls to `try_set_current`.
depth: Cell<u16>,
}
carllerche marked this conversation as resolved.
Show resolved Hide resolved

/// Sets this [`Handle`] as the current active [`Handle`].
///
/// [`Handle`]: crate::runtime::scheduler::Handle
pub(crate) fn try_set_current(handle: &scheduler::Handle) -> Option<SetCurrentGuard> {
CONTEXT
.try_with(|ctx| {
let old_handle = ctx.handle.borrow_mut().replace(handle.clone());

SetCurrentGuard {
old_handle,
_p: PhantomData,
}
})
.ok()
CONTEXT.try_with(|ctx| ctx.set_current(handle)).ok()
}

pub(crate) fn with_current<F, R>(f: F) -> Result<R, TryCurrentError>
where
F: FnOnce(&scheduler::Handle) -> R,
{
match CONTEXT.try_with(|ctx| ctx.handle.borrow().as_ref().map(f)) {
match CONTEXT.try_with(|ctx| ctx.current.handle.borrow().as_ref().map(f)) {
Ok(Some(ret)) => Ok(ret),
Ok(None) => Err(TryCurrentError::new_no_context()),
Err(_access_error) => Err(TryCurrentError::new_thread_local_destroyed()),
Expand All @@ -41,19 +47,53 @@ where

impl Context {
pub(super) fn set_current(&self, handle: &scheduler::Handle) -> SetCurrentGuard {
let old_handle = self.handle.borrow_mut().replace(handle.clone());
let old_handle = self.current.handle.borrow_mut().replace(handle.clone());
let depth = self.current.depth.get();

if depth == u16::MAX {
panic!("reached max `enter` depth");
}

let depth = depth + 1;
self.current.depth.set(depth);

SetCurrentGuard {
old_handle,
prev: old_handle,
depth,
_p: PhantomData,
}
}
}

impl HandleCell {
pub(super) const fn new() -> HandleCell {
HandleCell {
handle: RefCell::new(None),
depth: Cell::new(0),
}
}
}

impl Drop for SetCurrentGuard {
fn drop(&mut self) {
CONTEXT.with(|ctx| {
*ctx.handle.borrow_mut() = self.old_handle.take();
let depth = ctx.current.depth.get();

if depth != self.depth {
if !std::thread::panicking() {
panic!(
"`EnterGuard` values dropped out of order. Guards returned by \
`tokio::runtime::Handle::enter()` must be dropped in the reverse \
order as they were acquired."
);
} else {
// Just return... this will leave handles in a wonky state though...
return;
}
}

*ctx.current.handle.borrow_mut() = self.prev.take();
ctx.current.depth.set(depth - 1);
});
}
}
41 changes: 38 additions & 3 deletions tokio/src/runtime/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,44 @@ pub struct EnterGuard<'a> {

impl Handle {
/// Enters the runtime context. This allows you to construct types that must
/// have an executor available on creation such as [`Sleep`] or [`TcpStream`].
/// It will also allow you to call methods such as [`tokio::spawn`] and [`Handle::current`]
/// without panicking.
/// have an executor available on creation such as [`Sleep`] or
/// [`TcpStream`]. It will also allow you to call methods such as
/// [`tokio::spawn`] and [`Handle::current`] without panicking.
///
/// # Panics
///
/// When calling `Handle::enter` multiple times, the returned guards
/// **must** be dropped in the reverse order that they were acquired.
/// Failure to do so will result in a panic and possible memory leaks.
///
/// # Examples
///
/// ```
/// use tokio::runtime::Runtime;
///
/// let rt = Runtime::new().unwrap();
///
/// let _guard = rt.enter();
/// tokio::spawn(async {
/// println!("Hello world!");
/// });
/// ```
///
/// Do **not** do the following, this shows a scenario that will result in a
/// panic and possible memory leak.
///
/// ```should_panic
/// use tokio::runtime::Runtime;
///
/// let rt1 = Runtime::new().unwrap();
/// let rt2 = Runtime::new().unwrap();
///
/// let enter1 = rt1.enter();
/// let enter2 = rt2.enter();
///
/// drop(enter1);
/// drop(enter2);
/// ```
///
/// [`Sleep`]: struct@crate::time::Sleep
/// [`TcpStream`]: struct@crate::net::TcpStream
Expand Down
66 changes: 66 additions & 0 deletions tokio/tests/rt_handle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::runtime::Runtime;

#[test]
fn basic_enter() {
let rt1 = rt();
let rt2 = rt();

let enter1 = rt1.enter();
let enter2 = rt2.enter();

drop(enter2);
drop(enter1);
}

#[test]
#[should_panic]
fn interleave_enter_different_rt() {
let rt1 = rt();
let rt2 = rt();

let enter1 = rt1.enter();
let enter2 = rt2.enter();

drop(enter1);
drop(enter2);
}

#[test]
#[should_panic]
fn interleave_enter_same_rt() {
let rt1 = rt();

let _enter1 = rt1.enter();
let enter2 = rt1.enter();
let enter3 = rt1.enter();

drop(enter2);
drop(enter3);
}

#[test]
fn interleave_then_enter() {
let _ = std::panic::catch_unwind(|| {
let rt1 = rt();
let rt2 = rt();

let enter1 = rt1.enter();
let enter2 = rt2.enter();

drop(enter1);
drop(enter2);
});

// Can still enter
let rt3 = rt();
let _enter = rt3.enter();
}

fn rt() -> Runtime {
tokio::runtime::Builder::new_current_thread()
.build()
.unwrap()
}