diff --git a/tokio-util/src/task/join_map.rs b/tokio-util/src/task/join_map.rs
index d5b41e437f5..c6bf5bc241a 100644
--- a/tokio-util/src/task/join_map.rs
+++ b/tokio-util/src/task/join_map.rs
@@ -363,10 +363,7 @@ where
fn insert(&mut self, key: K, abort: AbortHandle) {
let hash = self.hash(&key);
let id = abort.id();
- let map_key = Key {
- id: id.clone(),
- key,
- };
+ let map_key = Key { id, key };
// Insert the new key into the map of tasks by keys.
let entry = self
diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs
index 6de657481e0..c3ef6469926 100644
--- a/tokio/src/runtime/context.rs
+++ b/tokio/src/runtime/context.rs
@@ -7,6 +7,7 @@ use crate::util::rand::{FastRand, RngSeed};
cfg_rt! {
use crate::runtime::scheduler;
+ use crate::runtime::task::Id;
use std::cell::RefCell;
use std::marker::PhantomData;
@@ -17,6 +18,8 @@ struct Context {
/// Handle to the runtime scheduler running on the current thread.
#[cfg(feature = "rt")]
handle: RefCell>,
+ #[cfg(feature = "rt")]
+ current_task_id: Cell >,
/// Tracks if the current thread is currently driving a runtime.
/// Note, that if this is set to "entered", the current scheduler
@@ -41,6 +44,8 @@ tokio_thread_local! {
/// accessing drivers, etc...
#[cfg(feature = "rt")]
handle: RefCell::new(None),
+ #[cfg(feature = "rt")]
+ current_task_id: Cell::new(None),
/// Tracks if the current thread is currently driving a runtime.
/// Note, that if this is set to "entered", the current scheduler
@@ -107,6 +112,14 @@ cfg_rt! {
pub(crate) struct DisallowBlockInPlaceGuard(bool);
+ pub(crate) fn set_current_task_id(id: Option) -> Option {
+ CONTEXT.try_with(|ctx| ctx.current_task_id.replace(id)).unwrap_or(None)
+ }
+
+ pub(crate) fn current_task_id() -> Option {
+ CONTEXT.try_with(|ctx| ctx.current_task_id.get()).unwrap_or(None)
+ }
+
pub(crate) fn try_current() -> Result {
match CONTEXT.try_with(|ctx| ctx.handle.borrow().clone()) {
Ok(Some(handle)) => Ok(handle),
diff --git a/tokio/src/runtime/task/abort.rs b/tokio/src/runtime/task/abort.rs
index c34e2bb9a02..bfdf53c5105 100644
--- a/tokio/src/runtime/task/abort.rs
+++ b/tokio/src/runtime/task/abort.rs
@@ -67,7 +67,7 @@ impl AbortHandle {
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn id(&self) -> super::Id {
- self.id.clone()
+ self.id
}
}
diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs
index 3e07d7c97fd..c4a7c6c72e3 100644
--- a/tokio/src/runtime/task/core.rs
+++ b/tokio/src/runtime/task/core.rs
@@ -11,6 +11,7 @@
use crate::future::Future;
use crate::loom::cell::UnsafeCell;
+use crate::runtime::context;
use crate::runtime::task::raw::{self, Vtable};
use crate::runtime::task::state::State;
use crate::runtime::task::{Id, Schedule};
@@ -157,6 +158,26 @@ impl CoreStage {
}
}
+/// Set and clear the task id in the context when the future is executed or
+/// dropped, or when the output produced by the future is dropped.
+pub(crate) struct TaskIdGuard {
+ parent_task_id: Option,
+}
+
+impl TaskIdGuard {
+ fn enter(id: Id) -> Self {
+ TaskIdGuard {
+ parent_task_id: context::set_current_task_id(Some(id)),
+ }
+ }
+}
+
+impl Drop for TaskIdGuard {
+ fn drop(&mut self) {
+ context::set_current_task_id(self.parent_task_id);
+ }
+}
+
impl Core {
/// Polls the future.
///
@@ -183,6 +204,7 @@ impl Core {
// Safety: The caller ensures the future is pinned.
let future = unsafe { Pin::new_unchecked(future) };
+ let _guard = TaskIdGuard::enter(self.task_id);
future.poll(&mut cx)
})
};
@@ -236,6 +258,7 @@ impl Core {
}
unsafe fn set_stage(&self, stage: Stage) {
+ let _guard = TaskIdGuard::enter(self.task_id);
self.stage.stage.with_mut(|ptr| *ptr = stage)
}
}
diff --git a/tokio/src/runtime/task/error.rs b/tokio/src/runtime/task/error.rs
index 7cf602abd33..f7ead77b7cc 100644
--- a/tokio/src/runtime/task/error.rs
+++ b/tokio/src/runtime/task/error.rs
@@ -128,7 +128,7 @@ impl JoinError {
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn id(&self) -> Id {
- self.id.clone()
+ self.id
}
}
diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs
index 545b01b7ff7..085aebe92ac 100644
--- a/tokio/src/runtime/task/harness.rs
+++ b/tokio/src/runtime/task/harness.rs
@@ -454,13 +454,12 @@ fn cancel_task(core: &Core) {
core.drop_future_or_output();
}));
- let id = core.task_id.clone();
match res {
Ok(()) => {
- core.store_output(Err(JoinError::cancelled(id)));
+ core.store_output(Err(JoinError::cancelled(core.task_id)));
}
Err(panic) => {
- core.store_output(Err(JoinError::panic(id, panic)));
+ core.store_output(Err(JoinError::panic(core.task_id, panic)));
}
}
}
@@ -492,7 +491,7 @@ fn poll_future(core: &Core, cx: Context<'_>) -> Po
Ok(Poll::Ready(output)) => Ok(output),
Err(panic) => {
core.scheduler.unhandled_panic();
- Err(JoinError::panic(core.task_id.clone(), panic))
+ Err(JoinError::panic(core.task_id, panic))
}
};
diff --git a/tokio/src/runtime/task/join.rs b/tokio/src/runtime/task/join.rs
index 31f6a6f8748..21ef2f1ba8b 100644
--- a/tokio/src/runtime/task/join.rs
+++ b/tokio/src/runtime/task/join.rs
@@ -267,7 +267,7 @@ impl JoinHandle {
raw.ref_inc();
raw
});
- super::AbortHandle::new(raw, self.id.clone())
+ super::AbortHandle::new(raw, self.id)
}
/// Returns a [task ID] that uniquely identifies this task relative to other
@@ -282,7 +282,7 @@ impl JoinHandle {
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn id(&self) -> super::Id {
- self.id.clone()
+ self.id
}
}
diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs
index 3d5b1cbf373..c31b110a476 100644
--- a/tokio/src/runtime/task/mod.rs
+++ b/tokio/src/runtime/task/mod.rs
@@ -139,6 +139,8 @@
// unstable. This should be removed once `JoinSet` is stabilized.
#![cfg_attr(not(tokio_unstable), allow(dead_code))]
+use crate::runtime::context;
+
mod core;
use self::core::Cell;
use self::core::Header;
@@ -193,6 +195,10 @@ use std::{fmt, mem};
/// task completes, the same ID may be used for another task.
/// - Task IDs are *not* sequential, and do not indicate the order in which
/// tasks are spawned, what runtime a task is spawned on, or any other data.
+/// - The task ID of the currently running task can be obtained from inside the
+/// task via the [`task::try_id()`](crate::task::try_id()) and
+/// [`task::id()`](crate::task::id()) functions and from outside the task via
+/// the [`JoinHandle::id()`](crate::task::JoinHandle::id()) function.
///
/// **Note**: This is an [unstable API][unstable]. The public API of this type
/// may break in 1.x releases. See [the documentation on unstable
@@ -201,10 +207,49 @@ use std::{fmt, mem};
/// [unstable]: crate#unstable-features
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
-// TODO(eliza): there's almost certainly no reason not to make this `Copy` as well...
-#[derive(Clone, Debug, Hash, Eq, PartialEq)]
+#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub struct Id(u64);
+/// Returns the [`Id`] of the currently running task.
+///
+/// # Panics
+///
+/// This function panics if called from outside a task. Please note that calls
+/// to `block_on` do not have task IDs, so the method will panic if called from
+/// within a call to `block_on`. For a version of this function that doesn't
+/// panic, see [`task::try_id()`](crate::runtime::task::try_id()).
+///
+/// **Note**: This is an [unstable API][unstable]. The public API of this type
+/// may break in 1.x releases. See [the documentation on unstable
+/// features][unstable] for details.
+///
+/// [task ID]: crate::task::Id
+/// [unstable]: crate#unstable-features
+#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
+#[track_caller]
+pub fn id() -> Id {
+ context::current_task_id().expect("Can't get a task id when not inside a task")
+}
+
+/// Returns the [`Id`] of the currently running task, or `None` if called outside
+/// of a task.
+///
+/// This function is similar to [`task::id()`](crate::runtime::task::id()), except
+/// that it returns `None` rather than panicking if called outside of a task
+/// context.
+///
+/// **Note**: This is an [unstable API][unstable]. The public API of this type
+/// may break in 1.x releases. See [the documentation on unstable
+/// features][unstable] for details.
+///
+/// [task ID]: crate::task::Id
+/// [unstable]: crate#unstable-features
+#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
+#[track_caller]
+pub fn try_id() -> Option {
+ context::current_task_id()
+}
+
/// An owned handle to the task, tracked by ref count.
#[repr(transparent)]
pub(crate) struct Task {
@@ -284,7 +329,7 @@ cfg_rt! {
T: Future + 'static,
T::Output: 'static,
{
- let raw = RawTask::new::(task, scheduler, id.clone());
+ let raw = RawTask::new::(task, scheduler, id);
let task = Task {
raw,
_p: PhantomData,
diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs
index f1683f7e07f..9b753701854 100644
--- a/tokio/src/task/mod.rs
+++ b/tokio/src/task/mod.rs
@@ -318,7 +318,7 @@ cfg_rt! {
pub mod join_set;
cfg_unstable! {
- pub use crate::runtime::task::Id;
+ pub use crate::runtime::task::{Id, id, try_id};
}
cfg_trace! {
diff --git a/tokio/tests/task_id.rs b/tokio/tests/task_id.rs
new file mode 100644
index 00000000000..d7b7c0cd812
--- /dev/null
+++ b/tokio/tests/task_id.rs
@@ -0,0 +1,303 @@
+#![warn(rust_2018_idioms)]
+#![allow(clippy::declare_interior_mutable_const)]
+#![cfg(all(feature = "full", tokio_unstable))]
+
+#[cfg(not(tokio_wasi))]
+use std::error::Error;
+use std::future::Future;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+#[cfg(not(tokio_wasi))]
+use tokio::runtime::{Builder, Runtime};
+use tokio::sync::oneshot;
+use tokio::task::{self, Id, LocalSet};
+
+#[cfg(not(tokio_wasi))]
+mod support {
+ pub mod panic;
+}
+#[cfg(not(tokio_wasi))]
+use support::panic::test_panic;
+
+#[tokio::test(flavor = "current_thread")]
+async fn task_id_spawn() {
+ tokio::spawn(async { println!("task id: {}", task::id()) })
+ .await
+ .unwrap();
+}
+
+#[cfg(not(tokio_wasi))]
+#[tokio::test(flavor = "current_thread")]
+async fn task_id_spawn_blocking() {
+ task::spawn_blocking(|| println!("task id: {}", task::id()))
+ .await
+ .unwrap();
+}
+
+#[tokio::test(flavor = "current_thread")]
+async fn task_id_collision_current_thread() {
+ let handle1 = tokio::spawn(async { task::id() });
+ let handle2 = tokio::spawn(async { task::id() });
+
+ let (id1, id2) = tokio::join!(handle1, handle2);
+ assert_ne!(id1.unwrap(), id2.unwrap());
+}
+
+#[cfg(not(tokio_wasi))]
+#[tokio::test(flavor = "multi_thread")]
+async fn task_id_collision_multi_thread() {
+ let handle1 = tokio::spawn(async { task::id() });
+ let handle2 = tokio::spawn(async { task::id() });
+
+ let (id1, id2) = tokio::join!(handle1, handle2);
+ assert_ne!(id1.unwrap(), id2.unwrap());
+}
+
+#[tokio::test(flavor = "current_thread")]
+async fn task_ids_match_current_thread() {
+ let (tx, rx) = oneshot::channel();
+ let handle = tokio::spawn(async {
+ let id = rx.await.unwrap();
+ assert_eq!(id, task::id());
+ });
+ tx.send(handle.id()).unwrap();
+ handle.await.unwrap();
+}
+
+#[cfg(not(tokio_wasi))]
+#[tokio::test(flavor = "multi_thread")]
+async fn task_ids_match_multi_thread() {
+ let (tx, rx) = oneshot::channel();
+ let handle = tokio::spawn(async {
+ let id = rx.await.unwrap();
+ assert_eq!(id, task::id());
+ });
+ tx.send(handle.id()).unwrap();
+ handle.await.unwrap();
+}
+
+#[cfg(not(tokio_wasi))]
+#[tokio::test(flavor = "multi_thread")]
+async fn task_id_future_destructor_completion() {
+ struct MyFuture {
+ tx: Option>,
+ }
+
+ impl Future for MyFuture {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
+ Poll::Ready(())
+ }
+ }
+
+ impl Drop for MyFuture {
+ fn drop(&mut self) {
+ let _ = self.tx.take().unwrap().send(task::id());
+ }
+ }
+
+ let (tx, rx) = oneshot::channel();
+ let handle = tokio::spawn(MyFuture { tx: Some(tx) });
+ let id = handle.id();
+ handle.await.unwrap();
+ assert_eq!(rx.await.unwrap(), id);
+}
+
+#[cfg(not(tokio_wasi))]
+#[tokio::test(flavor = "multi_thread")]
+async fn task_id_future_destructor_abort() {
+ struct MyFuture {
+ tx: Option>,
+ }
+
+ impl Future for MyFuture {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
+ Poll::Pending
+ }
+ }
+ impl Drop for MyFuture {
+ fn drop(&mut self) {
+ let _ = self.tx.take().unwrap().send(task::id());
+ }
+ }
+
+ let (tx, rx) = oneshot::channel();
+ let handle = tokio::spawn(MyFuture { tx: Some(tx) });
+ let id = handle.id();
+ handle.abort();
+ assert!(handle.await.unwrap_err().is_cancelled());
+ assert_eq!(rx.await.unwrap(), id);
+}
+
+#[tokio::test(flavor = "current_thread")]
+async fn task_id_output_destructor_handle_dropped_before_completion() {
+ struct MyOutput {
+ tx: Option>,
+ }
+
+ impl Drop for MyOutput {
+ fn drop(&mut self) {
+ let _ = self.tx.take().unwrap().send(task::id());
+ }
+ }
+
+ struct MyFuture {
+ tx: Option>,
+ }
+
+ impl Future for MyFuture {
+ type Output = MyOutput;
+
+ fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll {
+ Poll::Ready(MyOutput { tx: self.tx.take() })
+ }
+ }
+
+ let (tx, mut rx) = oneshot::channel();
+ let handle = tokio::spawn(MyFuture { tx: Some(tx) });
+ let id = handle.id();
+ drop(handle);
+ assert!(rx.try_recv().is_err());
+ assert_eq!(rx.await.unwrap(), id);
+}
+
+#[tokio::test(flavor = "current_thread")]
+async fn task_id_output_destructor_handle_dropped_after_completion() {
+ struct MyOutput {
+ tx: Option>,
+ }
+
+ impl Drop for MyOutput {
+ fn drop(&mut self) {
+ let _ = self.tx.take().unwrap().send(task::id());
+ }
+ }
+
+ struct MyFuture {
+ tx_output: Option>,
+ tx_future: Option>,
+ }
+
+ impl Future for MyFuture {
+ type Output = MyOutput;
+
+ fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll {
+ let _ = self.tx_future.take().unwrap().send(());
+ Poll::Ready(MyOutput {
+ tx: self.tx_output.take(),
+ })
+ }
+ }
+
+ let (tx_output, mut rx_output) = oneshot::channel();
+ let (tx_future, rx_future) = oneshot::channel();
+ let handle = tokio::spawn(MyFuture {
+ tx_output: Some(tx_output),
+ tx_future: Some(tx_future),
+ });
+ let id = handle.id();
+ rx_future.await.unwrap();
+ assert!(rx_output.try_recv().is_err());
+ drop(handle);
+ assert_eq!(rx_output.await.unwrap(), id);
+}
+
+#[test]
+fn task_try_id_outside_task() {
+ assert_eq!(None, task::try_id());
+}
+
+#[cfg(not(tokio_wasi))]
+#[test]
+fn task_try_id_inside_block_on() {
+ let rt = Runtime::new().unwrap();
+ rt.block_on(async {
+ assert_eq!(None, task::try_id());
+ });
+}
+
+#[tokio::test(flavor = "current_thread")]
+async fn task_id_spawn_local() {
+ LocalSet::new()
+ .run_until(async {
+ task::spawn_local(async { println!("task id: {}", task::id()) })
+ .await
+ .unwrap();
+ })
+ .await
+}
+
+#[tokio::test(flavor = "current_thread")]
+async fn task_id_nested_spawn_local() {
+ LocalSet::new()
+ .run_until(async {
+ task::spawn_local(async {
+ let parent_id = task::id();
+ LocalSet::new()
+ .run_until(async {
+ task::spawn_local(async move {
+ assert_ne!(parent_id, task::id());
+ })
+ .await
+ .unwrap();
+ })
+ .await;
+ assert_eq!(parent_id, task::id());
+ })
+ .await
+ .unwrap();
+ })
+ .await;
+}
+
+#[cfg(not(tokio_wasi))]
+#[tokio::test(flavor = "multi_thread")]
+async fn task_id_block_in_place_block_on_spawn() {
+ task::spawn(async {
+ let parent_id = task::id();
+
+ task::block_in_place(move || {
+ let rt = Builder::new_current_thread().build().unwrap();
+ rt.block_on(rt.spawn(async move {
+ assert_ne!(parent_id, task::id());
+ }))
+ .unwrap();
+ });
+
+ assert_eq!(parent_id, task::id());
+ })
+ .await
+ .unwrap();
+}
+
+#[cfg(not(tokio_wasi))]
+#[test]
+fn task_id_outside_task_panic_caller() -> Result<(), Box> {
+ let panic_location_file = test_panic(|| {
+ let _ = task::id();
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[cfg(not(tokio_wasi))]
+#[test]
+fn task_id_inside_block_on_panic_caller() -> Result<(), Box> {
+ let panic_location_file = test_panic(|| {
+ let rt = Runtime::new().unwrap();
+ rt.block_on(async {
+ task::id();
+ });
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}