diff --git a/crates/bevy_ecs/src/schedule/executor_parallel.rs b/crates/bevy_ecs/src/schedule/executor_parallel.rs index 1663d26646241..09e8f07301170 100644 --- a/crates/bevy_ecs/src/schedule/executor_parallel.rs +++ b/crates/bevy_ecs/src/schedule/executor_parallel.rs @@ -166,7 +166,7 @@ impl ParallelExecutor { /// queues systems with no dependencies to run (or skip) at next opportunity. fn prepare_systems<'scope>( &mut self, - scope: &mut Scope<'scope, ()>, + scope: &Scope<'_, 'scope, ()>, systems: &'scope mut [SystemContainer], world: &'scope World, ) { @@ -236,7 +236,7 @@ impl ParallelExecutor { if system_data.is_send { scope.spawn(task); } else { - scope.spawn_local(task); + scope.spawn_on_scope(task); } #[cfg(test)] @@ -271,7 +271,7 @@ impl ParallelExecutor { if system_data.is_send { scope.spawn(task); } else { - scope.spawn_local(task); + scope.spawn_on_scope(task); } } } diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml index 91f1e8e5e334c..0358f95c1b458 100644 --- a/crates/bevy_tasks/Cargo.toml +++ b/crates/bevy_tasks/Cargo.toml @@ -13,6 +13,7 @@ futures-lite = "1.4.0" async-executor = "1.3.0" async-channel = "1.4.2" once_cell = "1.7" +concurrent-queue = "1.2.2" [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen-futures = "0.4" diff --git a/crates/bevy_tasks/src/single_threaded_task_pool.rs b/crates/bevy_tasks/src/single_threaded_task_pool.rs index 757b711d999d3..8fa37f4f2361b 100644 --- a/crates/bevy_tasks/src/single_threaded_task_pool.rs +++ b/crates/bevy_tasks/src/single_threaded_task_pool.rs @@ -1,5 +1,6 @@ use std::{ future::Future, + marker::PhantomData, mem, sync::{Arc, Mutex}, }; @@ -61,27 +62,34 @@ impl TaskPool { /// to spawn tasks. This function will await the completion of all tasks before returning. /// /// This is similar to `rayon::scope` and `crossbeam::scope` - pub fn scope<'scope, F, T>(&self, f: F) -> Vec + pub fn scope<'env, F, T>(&self, f: F) -> Vec where - F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send, + F: for<'scope> FnOnce(&'env mut Scope<'scope, 'env, T>), T: Send + 'static, { let executor = &async_executor::LocalExecutor::new(); - let executor: &'scope async_executor::LocalExecutor<'scope> = + let executor: &'env async_executor::LocalExecutor<'env> = unsafe { mem::transmute(executor) }; + let results: Mutex>>>> = Mutex::new(Vec::new()); + let results: &'env Mutex>>>> = unsafe { mem::transmute(&results) }; + let mut scope = Scope { executor, - results: Vec::new(), + results, + scope: PhantomData, + env: PhantomData, }; - f(&mut scope); + let scope_ref: &'env mut Scope<'_, 'env, T> = unsafe { mem::transmute(&mut scope) }; + + f(scope_ref); // Loop until all tasks are done while executor.try_tick() {} - scope - .results + let results = scope.results.lock().unwrap(); + results .iter() .map(|result| result.lock().unwrap().take().unwrap()) .collect() @@ -127,13 +135,17 @@ impl FakeTask { /// /// For more information, see [`TaskPool::scope`]. #[derive(Debug)] -pub struct Scope<'scope, T> { - executor: &'scope async_executor::LocalExecutor<'scope>, +pub struct Scope<'scope, 'env: 'scope, T> { + executor: &'env async_executor::LocalExecutor<'env>, // Vector to gather results of all futures spawned during scope run - results: Vec>>>, + results: &'env Mutex>>>>, + + // make `Scope` invariant over 'scope and 'env + scope: PhantomData<&'scope mut &'scope ()>, + env: PhantomData<&'env mut &'env ()>, } -impl<'scope, T: Send + 'scope> Scope<'scope, T> { +impl<'scope, 'env, T: Send + 'env> Scope<'scope, 'env, T> { /// Spawns a scoped future onto the thread-local executor. The scope *must* outlive /// the provided future. The results of the future will be returned as a part of /// [`TaskPool::scope`]'s return value. @@ -141,18 +153,18 @@ impl<'scope, T: Send + 'scope> Scope<'scope, T> { /// On the single threaded task pool, it just calls [`Scope::spawn_local`]. /// /// For more information, see [`TaskPool::scope`]. - pub fn spawn + 'scope + Send>(&mut self, f: Fut) { - self.spawn_local(f); + pub fn spawn + 'env>(&self, f: Fut) { + self.spawn_on_scope(f); } - /// Spawns a scoped future onto the thread-local executor. The scope *must* outlive - /// the provided future. The results of the future will be returned as a part of - /// [`TaskPool::scope`]'s return value. + /// Spawns a scoped future that runs on the thread the scope called from. The + /// scope *must* outlive the provided future. The results of the future will be + /// returned as a part of [`TaskPool::scope`]'s return value. /// /// For more information, see [`TaskPool::scope`]. - pub fn spawn_local + 'scope>(&mut self, f: Fut) { + pub fn spawn_on_scope + 'env>(&self, f: Fut) { let result = Arc::new(Mutex::new(None)); - self.results.push(result.clone()); + self.results.lock().unwrap().push(result.clone()); let f = async move { result.lock().unwrap().replace(f.await); }; diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index c9c97322f56f7..e0ac0101d56ac 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -1,11 +1,13 @@ use std::{ future::Future, + marker::PhantomData, mem, pin::Pin, sync::Arc, thread::{self, JoinHandle}, }; +use concurrent_queue::ConcurrentQueue; use futures_lite::{future, pin}; use crate::Task; @@ -140,69 +142,145 @@ impl TaskPool { /// to spawn tasks. This function will await the completion of all tasks before returning. /// /// This is similar to `rayon::scope` and `crossbeam::scope` - pub fn scope<'scope, F, T>(&self, f: F) -> Vec + /// + /// # Example + /// + /// ``` + /// use bevy_tasks::TaskPool; + /// + /// let pool = TaskPool::new(); + /// let mut x = 0; + /// let results = pool.scope(|s| { + /// s.spawn(async { + /// // you can borrow the spawner inside a task and spawn tasks from within the task + /// s.spawn(async { + /// // borrow x and mutate it. + /// x = 2; + /// // return a value from the task + /// 1 + /// }); + /// // return some other value from the first task + /// 0 + /// }); + /// }); + /// + /// // results are returned in the order the tasks are spawned in. + /// // Note: the ordering may become non-deterministic if you spawn from within tasks. + /// // the ordering is only guaranteed when tasks are spawned directly from the main closure. + /// assert_eq!(&results[..], &[0, 1]); + /// // can access x after scope runs + /// assert_eq!(x, 2); + /// ``` + /// + /// # Lifetimes + /// + /// The [`Scope`] object takes two lifetimes: `'scope` and `'env`. + /// + /// The `'scope` lifetime represents the lifetime of the scope. That is the time during + /// which the provided closure and tasks that are spawned into the scope are run. + /// + /// The `'env` lifetime represents the lifetime of whatever is borrowed by the scope. + /// Thus this lifetime must outlive `'scope`. + /// + /// ```compile_fail + /// use bevy_tasks::TaskPool; + /// fn scope_escapes_closure() { + /// let pool = TaskPool::new(); + /// let foo = Box::new(42); + /// pool.scope(|scope| { + /// std::thread::spawn(move || { + /// // UB. This could spawn on the scope after `.scope` returns and the internal Scope is dropped. + /// scope.spawn(async move { + /// assert_eq!(*foo, 42); + /// }); + /// }); + /// }); + /// } + /// ``` + /// + /// ```compile_fail + /// use bevy_tasks::TaskPool; + /// fn cannot_borrow_from_closure() { + /// let pool = TaskPool::new(); + /// pool.scope(|scope| { + /// let x = 1; + /// let y = &x; + /// scope.spawn(async move { + /// assert_eq!(*y, 1); + /// }); + /// }); + /// } + /// + pub fn scope<'env, F, T>(&self, f: F) -> Vec where - F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send, + F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>), T: Send + 'static, { - TaskPool::LOCAL_EXECUTOR.with(|local_executor| { - // SAFETY: This function blocks until all futures complete, so this future must return - // before this function returns. However, rust has no way of knowing - // this so we must convert to 'static here to appease the compiler as it is unable to - // validate safety. - let executor: &async_executor::Executor = &self.executor; - let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) }; - let local_executor: &'scope async_executor::LocalExecutor = - unsafe { mem::transmute(local_executor) }; - let mut scope = Scope { - executor, - local_executor, - spawned: Vec::new(), - }; + // SAFETY: This safety comment applies to all references transmuted to 'env. + // Any futures spawned with these references need to return before this function completes. + // This is guaranteed because we drive all the futures spawned onto the Scope + // to completion in this function. However, rust has no way of knowing this so we + // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety. + let executor: &async_executor::Executor = &*self.executor; + let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) }; + let task_scope_executor = &async_executor::Executor::default(); + let task_scope_executor: &'env async_executor::Executor = + unsafe { mem::transmute(task_scope_executor) }; + let spawned: ConcurrentQueue> = ConcurrentQueue::unbounded(); + let spawned_ref: &'env ConcurrentQueue> = + unsafe { mem::transmute(&spawned) }; + + let scope = Scope { + executor, + task_scope_executor, + spawned: spawned_ref, + scope: PhantomData, + env: PhantomData, + }; + + let scope_ref: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) }; + + f(scope_ref); + + if spawned.is_empty() { + Vec::new() + } else { + let get_results = async move { + let mut results = Vec::with_capacity(spawned.len()); + while let Ok(task) = spawned.pop() { + results.push(task.await); + } - f(&mut scope); - - if scope.spawned.is_empty() { - Vec::default() - } else if scope.spawned.len() == 1 { - vec![future::block_on(&mut scope.spawned[0])] - } else { - let fut = async move { - let mut results = Vec::with_capacity(scope.spawned.len()); - for task in scope.spawned { - results.push(task.await); - } + results + }; - results + // Pin the futures on the stack. + pin!(get_results); + + // SAFETY: This function blocks until all futures complete, so we do not read/write + // the data from futures outside of the 'scope lifetime. However, + // rust has no way of knowing this so we must convert to 'static + // here to appease the compiler as it is unable to validate safety. + let get_results: Pin<&mut (dyn Future> + 'static + Send)> = get_results; + let get_results: Pin<&'static mut (dyn Future> + 'static + Send)> = + unsafe { mem::transmute(get_results) }; + + // The thread that calls scope() will participate in driving tasks in the pool + // forward until the tasks that are spawned by this scope() call + // complete. (If the caller of scope() happens to be a thread in + // this thread pool, and we only have one thread in the pool, then + // simply calling future::block_on(spawned) would deadlock.) + let mut spawned = task_scope_executor.spawn(get_results); + + loop { + if let Some(result) = future::block_on(future::poll_once(&mut spawned)) { + break result; }; - // Pin the futures on the stack. - pin!(fut); - - // SAFETY: This function blocks until all futures complete, so we do not read/write - // the data from futures outside of the 'scope lifetime. However, - // rust has no way of knowing this so we must convert to 'static - // here to appease the compiler as it is unable to validate safety. - let fut: Pin<&mut (dyn Future>)> = fut; - let fut: Pin<&'static mut (dyn Future> + 'static)> = - unsafe { mem::transmute(fut) }; - - // The thread that calls scope() will participate in driving tasks in the pool - // forward until the tasks that are spawned by this scope() call - // complete. (If the caller of scope() happens to be a thread in - // this thread pool, and we only have one thread in the pool, then - // simply calling future::block_on(spawned) would deadlock.) - let mut spawned = local_executor.spawn(fut); - loop { - if let Some(result) = future::block_on(future::poll_once(&mut spawned)) { - break result; - }; - - self.executor.try_tick(); - local_executor.try_tick(); - } + self.executor.try_tick(); + task_scope_executor.try_tick(); } - }) + } } /// Spawns a static future onto the thread pool. The returned Task is a future. It can also be @@ -254,35 +332,42 @@ impl Drop for TaskPool { /// /// For more information, see [`TaskPool::scope`]. #[derive(Debug)] -pub struct Scope<'scope, T> { +pub struct Scope<'scope, 'env: 'scope, T> { executor: &'scope async_executor::Executor<'scope>, - local_executor: &'scope async_executor::LocalExecutor<'scope>, - spawned: Vec>, + task_scope_executor: &'scope async_executor::Executor<'scope>, + spawned: &'scope ConcurrentQueue>, + // make `Scope` invariant over 'scope and 'env + scope: PhantomData<&'scope mut &'scope ()>, + env: PhantomData<&'env mut &'env ()>, } -impl<'scope, T: Send + 'scope> Scope<'scope, T> { +impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> { /// Spawns a scoped future onto the thread pool. The scope *must* outlive /// the provided future. The results of the future will be returned as a part of /// [`TaskPool::scope`]'s return value. /// - /// If the provided future is non-`Send`, [`Scope::spawn_local`] should be used + /// For futures that should run on the thread `scope` is called on [`Scope::spawn_on_scope`] should be used /// instead. /// /// For more information, see [`TaskPool::scope`]. - pub fn spawn + 'scope + Send>(&mut self, f: Fut) { + pub fn spawn + 'scope + Send>(&self, f: Fut) { let task = self.executor.spawn(f); - self.spawned.push(task); + // ConcurrentQueue only errors when closed or full, but we never + // close and use an unbouded queue, so it is safe to unwrap + self.spawned.push(task).unwrap(); } - /// Spawns a scoped future onto the thread-local executor. The scope *must* outlive + /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive /// the provided future. The results of the future will be returned as a part of /// [`TaskPool::scope`]'s return value. Users should generally prefer to use - /// [`Scope::spawn`] instead, unless the provided future is not `Send`. + /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread. /// /// For more information, see [`TaskPool::scope`]. - pub fn spawn_local + 'scope>(&mut self, f: Fut) { - let task = self.local_executor.spawn(f); - self.spawned.push(task); + pub fn spawn_on_scope + 'scope + Send>(&self, f: Fut) { + let task = self.task_scope_executor.spawn(f); + // ConcurrentQueue only errors when closed or full, but we never + // close and use an unbouded queue, so it is safe to unwrap + self.spawned.push(task).unwrap(); } } @@ -327,7 +412,7 @@ mod tests { } #[test] - fn test_mixed_spawn_local_and_spawn() { + fn test_mixed_spawn_on_scope_and_spawn() { let pool = TaskPool::new(); let foo = Box::new(42); @@ -350,7 +435,7 @@ mod tests { }); } else { let count_clone = local_count.clone(); - scope.spawn_local(async move { + scope.spawn_on_scope(async move { if *foo != 42 { panic!("not 42!?!?") } else { @@ -391,7 +476,7 @@ mod tests { }); let spawner = std::thread::current().id(); let inner_count_clone = count_clone.clone(); - scope.spawn_local(async move { + scope.spawn_on_scope(async move { inner_count_clone.fetch_add(1, Ordering::Release); if std::thread::current().id() != spawner { // NOTE: This check is using an atomic rather than simply panicing the @@ -407,4 +492,80 @@ mod tests { assert!(!thread_check_failed.load(Ordering::Acquire)); assert_eq!(count.load(Ordering::Acquire), 200); } + + #[test] + fn test_nested_spawn() { + let pool = TaskPool::new(); + + let foo = Box::new(42); + let foo = &*foo; + + let count = Arc::new(AtomicI32::new(0)); + + let outputs: Vec = pool.scope(|scope| { + for _ in 0..10 { + let count_clone = count.clone(); + scope.spawn(async move { + for _ in 0..10 { + let count_clone_clone = count_clone.clone(); + scope.spawn(async move { + if *foo != 42 { + panic!("not 42!?!?") + } else { + count_clone_clone.fetch_add(1, Ordering::Relaxed); + *foo + } + }); + } + *foo + }); + } + }); + + for output in &outputs { + assert_eq!(*output, 42); + } + + // the inner loop runs 100 times and the outer one runs 10. 100 + 10 + assert_eq!(outputs.len(), 110); + assert_eq!(count.load(Ordering::Relaxed), 100); + } + + #[test] + fn test_nested_locality() { + let pool = Arc::new(TaskPool::new()); + let count = Arc::new(AtomicI32::new(0)); + let barrier = Arc::new(Barrier::new(101)); + let thread_check_failed = Arc::new(AtomicBool::new(false)); + + for _ in 0..100 { + let inner_barrier = barrier.clone(); + let count_clone = count.clone(); + let inner_pool = pool.clone(); + let inner_thread_check_failed = thread_check_failed.clone(); + std::thread::spawn(move || { + inner_pool.scope(|scope| { + let spawner = std::thread::current().id(); + let inner_count_clone = count_clone.clone(); + scope.spawn(async move { + inner_count_clone.fetch_add(1, Ordering::Release); + + // spawning on the scope from another thread runs the futures on the scope's thread + scope.spawn_on_scope(async move { + inner_count_clone.fetch_add(1, Ordering::Release); + if std::thread::current().id() != spawner { + // NOTE: This check is using an atomic rather than simply panicing the + // thread to avoid deadlocking the barrier on failure + inner_thread_check_failed.store(true, Ordering::Release); + } + }); + }); + }); + inner_barrier.wait(); + }); + } + barrier.wait(); + assert!(!thread_check_failed.load(Ordering::Acquire)); + assert_eq!(count.load(Ordering::Acquire), 200); + } }