Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make task::Builder::spawn* methods fallible #4823

Merged
merged 3 commits into from Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion tokio/src/runtime/blocking/mod.rs
Expand Up @@ -4,7 +4,7 @@
//! compilation.

mod pool;
pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, Spawner, Task};
pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, SpawnError, Spawner, Task};

cfg_fs! {
pub(crate) use pool::spawn_mandatory_blocking;
Expand Down
26 changes: 23 additions & 3 deletions tokio/src/runtime/blocking/pool.rs
Expand Up @@ -11,6 +11,7 @@ use crate::runtime::{Builder, Callback, ToHandle};

use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::io;
use std::time::Duration;

pub(crate) struct BlockingPool {
Expand Down Expand Up @@ -82,6 +83,25 @@ pub(crate) enum Mandatory {
NonMandatory,
}

pub(crate) enum SpawnError {
/// Pool is shutting down and the task was not scheduled
ShuttingDown,
/// There are no worker threads available to take the task
/// and the OS failed to spawn a new one
NoThreads(io::Error),
}

impl From<SpawnError> for io::Error {
fn from(e: SpawnError) -> Self {
match e {
SpawnError::ShuttingDown => {
io::Error::new(io::ErrorKind::Other, "blocking pool shutting down")
}
SpawnError::NoThreads(e) => e,
}
}
}

impl Task {
pub(crate) fn new(task: task::UnownedTask<NoopSchedule>, mandatory: Mandatory) -> Task {
Task { task, mandatory }
Expand Down Expand Up @@ -220,7 +240,7 @@ impl fmt::Debug for BlockingPool {
// ===== impl Spawner =====

impl Spawner {
pub(crate) fn spawn(&self, task: Task, rt: &dyn ToHandle) -> Result<(), ()> {
pub(crate) fn spawn(&self, task: Task, rt: &dyn ToHandle) -> Result<(), SpawnError> {
let mut shared = self.inner.shared.lock();

if shared.shutdown {
Expand All @@ -230,7 +250,7 @@ impl Spawner {
task.task.shutdown();

// no need to even push this task; it would never get picked up
return Err(());
return Err(SpawnError::ShuttingDown);
}

shared.queue.push_back(task);
Expand Down Expand Up @@ -261,7 +281,7 @@ impl Spawner {
Err(e) => {
// The OS refused to spawn the thread and there is no thread
// to pick up the task that has just been pushed to the queue.
panic!("OS can't spawn worker thread: {}", e)
return Err(SpawnError::NoThreads(e));
}
}
}
Expand Down
19 changes: 13 additions & 6 deletions tokio/src/runtime/handle.rs
Expand Up @@ -341,15 +341,22 @@ impl HandleInner {
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (join_handle, _was_spawned) = if cfg!(debug_assertions)
let (join_handle, spawn_result) = if cfg!(debug_assertions)
&& std::mem::size_of::<F>() > 2048
{
self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None, rt)
} else {
self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None, rt)
};

join_handle
match spawn_result {
Ok(()) => join_handle,
// Compat: do not panic here, return the join_handle even though it will never resolve
Err(blocking::SpawnError::ShuttingDown) => join_handle,
Err(blocking::SpawnError::NoThreads(e)) => {
panic!("OS can't spawn worker thread: {}", e)
}
}
}

cfg_fs! {
Expand All @@ -363,7 +370,7 @@ impl HandleInner {
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
self.spawn_blocking_inner(
Box::new(func),
blocking::Mandatory::Mandatory,
Expand All @@ -379,7 +386,7 @@ impl HandleInner {
)
};

if was_spawned {
if spawn_result.is_ok() {
Some(join_handle)
} else {
None
Expand All @@ -394,7 +401,7 @@ impl HandleInner {
is_mandatory: blocking::Mandatory,
name: Option<&str>,
rt: &dyn ToHandle,
) -> (JoinHandle<R>, bool)
) -> (JoinHandle<R>, Result<(), blocking::SpawnError>)
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
Expand Down Expand Up @@ -424,7 +431,7 @@ impl HandleInner {
let spawned = self
.blocking_spawner
.spawn(blocking::Task::new(task, is_mandatory), rt);
(handle, spawned.is_ok())
(handle, spawned)
}
}

Expand Down
41 changes: 27 additions & 14 deletions tokio/src/task/builder.rs
Expand Up @@ -3,7 +3,7 @@ use crate::{
runtime::{context, Handle},
task::{JoinHandle, LocalSet},
};
use std::future::Future;
use std::{future::Future, io};

/// Factory which is used to configure the properties of a new task.
///
Expand Down Expand Up @@ -48,7 +48,7 @@ use std::future::Future;
/// .spawn(async move {
/// // Process each socket concurrently.
/// process(socket).await
/// });
/// })?;
/// }
/// }
/// ```
Expand Down Expand Up @@ -83,12 +83,12 @@ impl<'a> Builder<'a> {
/// See [`task::spawn`](crate::task::spawn) for
/// more details.
#[track_caller]
pub fn spawn<Fut>(self, future: Fut) -> JoinHandle<Fut::Output>
pub fn spawn<Fut>(self, future: Fut) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
super::spawn::spawn_inner(future, self.name)
Ok(super::spawn::spawn_inner(future, self.name))
}

/// Spawn a task with this builder's settings on the provided [runtime
Expand All @@ -99,12 +99,16 @@ impl<'a> Builder<'a> {
/// [runtime handle]: crate::runtime::Handle
/// [`Handle::spawn`]: crate::runtime::Handle::spawn
#[track_caller]
pub fn spawn_on<Fut>(&mut self, future: Fut, handle: &Handle) -> JoinHandle<Fut::Output>
pub fn spawn_on<Fut>(
&mut self,
future: Fut,
handle: &Handle,
) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
handle.spawn_named(future, self.name)
Ok(handle.spawn_named(future, self.name))
}

/// Spawns `!Send` a task on the current [`LocalSet`] with this builder's
Expand All @@ -122,12 +126,12 @@ impl<'a> Builder<'a> {
/// [`task::spawn_local`]: crate::task::spawn_local
/// [`LocalSet`]: crate::task::LocalSet
#[track_caller]
pub fn spawn_local<Fut>(self, future: Fut) -> JoinHandle<Fut::Output>
pub fn spawn_local<Fut>(self, future: Fut) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + 'static,
Fut::Output: 'static,
{
super::local::spawn_local_inner(future, self.name)
Ok(super::local::spawn_local_inner(future, self.name))
}

/// Spawns `!Send` a task on the provided [`LocalSet`] with this builder's
Expand All @@ -138,12 +142,16 @@ impl<'a> Builder<'a> {
/// [`LocalSet::spawn_local`]: crate::task::LocalSet::spawn_local
/// [`LocalSet`]: crate::task::LocalSet
#[track_caller]
pub fn spawn_local_on<Fut>(self, future: Fut, local_set: &LocalSet) -> JoinHandle<Fut::Output>
pub fn spawn_local_on<Fut>(
self,
future: Fut,
local_set: &LocalSet,
) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + 'static,
Fut::Output: 'static,
{
local_set.spawn_named(future, self.name)
Ok(local_set.spawn_named(future, self.name))
}

/// Spawns blocking code on the blocking threadpool.
Expand All @@ -155,7 +163,10 @@ impl<'a> Builder<'a> {
/// See [`task::spawn_blocking`](crate::task::spawn_blocking)
/// for more details.
#[track_caller]
pub fn spawn_blocking<Function, Output>(self, function: Function) -> JoinHandle<Output>
pub fn spawn_blocking<Function, Output>(
self,
function: Function,
) -> io::Result<JoinHandle<Output>>
where
Function: FnOnce() -> Output + Send + 'static,
Output: Send + 'static,
Expand All @@ -174,18 +185,20 @@ impl<'a> Builder<'a> {
self,
function: Function,
handle: &Handle,
) -> JoinHandle<Output>
) -> io::Result<JoinHandle<Output>>
where
Function: FnOnce() -> Output + Send + 'static,
Output: Send + 'static,
{
use crate::runtime::Mandatory;
let (join_handle, _was_spawned) = handle.as_inner().spawn_blocking_inner(
let (join_handle, spawn_result) = handle.as_inner().spawn_blocking_inner(
function,
Mandatory::NonMandatory,
self.name,
handle,
);
join_handle

spawn_result?;
Ok(join_handle)
}
}
25 changes: 14 additions & 11 deletions tokio/src/task/join_set.rs
Expand Up @@ -101,13 +101,15 @@ impl<T: 'static> JoinSet<T> {
/// use tokio::task::JoinSet;
///
/// #[tokio::main]
/// async fn main() {
/// async fn main() -> std::io::Result<()> {
/// let mut set = JoinSet::new();
///
/// // Use the builder to configure a task's name before spawning it.
/// set.build_task()
/// .name("my_task")
/// .spawn(async { /* ... */ });
/// .spawn(async { /* ... */ })?;
///
/// Ok(())
/// }
/// ```
#[cfg(all(tokio_unstable, feature = "tracing"))]
Expand Down Expand Up @@ -377,13 +379,13 @@ impl<'a, T: 'static> Builder<'a, T> {
///
/// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller]
pub fn spawn<F>(self, future: F) -> AbortHandle
pub fn spawn<F>(self, future: F) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.joinset.insert(self.builder.spawn(future))
Ok(self.joinset.insert(self.builder.spawn(future)?))
}

/// Spawn the provided task on the provided [runtime handle] with this
Expand All @@ -397,13 +399,13 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`AbortHandle`]: crate::task::AbortHandle
/// [runtime handle]: crate::runtime::Handle
#[track_caller]
pub fn spawn_on<F>(mut self, future: F, handle: &Handle) -> AbortHandle
pub fn spawn_on<F>(mut self, future: F, handle: &Handle) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.joinset.insert(self.builder.spawn_on(future, handle))
Ok(self.joinset.insert(self.builder.spawn_on(future, handle)?))
}

/// Spawn the provided task on the current [`LocalSet`] with this builder's
Expand All @@ -420,12 +422,12 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`LocalSet`]: crate::task::LocalSet
/// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller]
pub fn spawn_local<F>(self, future: F) -> AbortHandle
pub fn spawn_local<F>(self, future: F) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: 'static,
{
self.joinset.insert(self.builder.spawn_local(future))
Ok(self.joinset.insert(self.builder.spawn_local(future)?))
}

/// Spawn the provided task on the provided [`LocalSet`] with this builder's
Expand All @@ -438,13 +440,14 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`LocalSet`]: crate::task::LocalSet
/// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller]
pub fn spawn_local_on<F>(self, future: F, local_set: &LocalSet) -> AbortHandle
pub fn spawn_local_on<F>(self, future: F, local_set: &LocalSet) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: 'static,
{
self.joinset
.insert(self.builder.spawn_local_on(future, local_set))
Ok(self
.joinset
.insert(self.builder.spawn_local_on(future, local_set)?))
}
}

Expand Down
20 changes: 17 additions & 3 deletions tokio/tests/task_builder.rs
Expand Up @@ -11,6 +11,7 @@ mod tests {
let result = Builder::new()
.name("name")
.spawn(async { "task executed" })
.unwrap()
.await;

assert_eq!(result.unwrap(), "task executed");
Expand All @@ -21,6 +22,7 @@ mod tests {
let result = Builder::new()
.name("name")
.spawn_blocking(|| "task executed")
.unwrap()
.await;

assert_eq!(result.unwrap(), "task executed");
Expand All @@ -34,6 +36,7 @@ mod tests {
Builder::new()
.name("name")
.spawn_local(async move { unsend_data })
.unwrap()
.await
})
.await;
Expand All @@ -43,14 +46,20 @@ mod tests {

#[test]
async fn spawn_without_name() {
let result = Builder::new().spawn(async { "task executed" }).await;
let result = Builder::new()
.spawn(async { "task executed" })
.unwrap()
.await;

assert_eq!(result.unwrap(), "task executed");
}

#[test]
async fn spawn_blocking_without_name() {
let result = Builder::new().spawn_blocking(|| "task executed").await;
let result = Builder::new()
.spawn_blocking(|| "task executed")
.unwrap()
.await;

assert_eq!(result.unwrap(), "task executed");
}
Expand All @@ -59,7 +68,12 @@ mod tests {
async fn spawn_local_without_name() {
let unsend_data = Rc::new("task executed");
let result = LocalSet::new()
.run_until(async move { Builder::new().spawn_local(async move { unsend_data }).await })
.run_until(async move {
Builder::new()
.spawn_local(async move { unsend_data })
.unwrap()
.await
})
.await;

assert_eq!(*result.unwrap(), "task executed");
Expand Down