As part of a larger program, I recently had a need for a task group that:
- Limits the number of tasks active at once
- Returns a stream of task return values
- Aborts all tasks when the stream is dropped
- Gives each task a way to spawn more tasks in the same group (giving rise to a directed acyclic graph or tree of tasks, hence the naming)
Not-entirely-thorough searching didn't turn up any pre-existing matches on crates.io, so I ended up writing my own. It seems to work fine, but with this much asyncery, I thought I'd get some more eyes on it.
Aside from a need for general code review advice, I'm concerned about the following:
- Is it a good idea to block on `std::sync::Mutex::lock()` inside a poll method? The mutex should only ever be locked for very small amounts of time (I think; how fast is `JoinSet::spawn`?), but it's theoretically possible that a swarm of tasks spawning in quick succession could hog the mutex and keep `poll_next()` from returning for a while.
- Is it a good idea to check whether all the `Spawner`s have been dropped by seeing whether the strong count of `Arc<Mutex<JoinSet>>` is one? I don't believe there would be any false negatives with this approach, but I'm open to being proven wrong.
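To make the second question concrete, here is a minimal std-only sketch of the strong-count check in isolation. `Owner` and `Handle` are hypothetical stand-ins for the nursery and `Spawner`, and a plain `Vec` replaces the `JoinSet`, so this only illustrates the counting logic, not the async behaviour.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for `Spawner<T>`: a handle sharing the task set.
struct Handle<T> {
    tasks: Arc<Mutex<Vec<T>>>,
}

// Written by hand so that no `T: Clone` bound is required.
impl<T> Clone for Handle<T> {
    fn clone(&self) -> Self {
        Handle {
            tasks: Arc::clone(&self.tasks),
        }
    }
}

/// Hypothetical stand-in for the nursery side of the group.
struct Owner<T> {
    tasks: Arc<Mutex<Vec<T>>>,
}

impl<T> Owner<T> {
    /// Create the shared state plus one initial handle.
    fn new() -> (Self, Handle<T>) {
        let tasks = Arc::new(Mutex::new(Vec::new()));
        let owner = Owner {
            tasks: Arc::clone(&tasks),
        };
        (owner, Handle { tasks })
    }

    /// True when the owner holds the only remaining strong reference,
    /// i.e. every `Handle` has been dropped.
    fn all_handles_dropped(&self) -> bool {
        Arc::strong_count(&self.tasks) == 1
    }
}
```

As long as no `Weak` references exist and the owner itself is never cloned, a reading of 1 cannot go stale, because nobody is left who could create another handle. A reading above 1, on the other hand, is only a snapshot: another thread may be dropping its handle at that very moment.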
use futures_util::Stream;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{ready, Context, Poll};
use tokio::{sync::Semaphore, task::JoinSet};
/// A task group with the following properties:
///
/// - No more than a certain number of tasks are ever active at once.
///
/// - Each task is passed a `Spawner` that can be used to spawn more tasks in
/// the group.
///
/// - `BoundedTreeNursery<T>` is a `Stream` of the return values of the tasks
/// (which must all be `T`).
///
/// - Dropping `BoundedTreeNursery` causes all tasks to be aborted.
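// `Clone` is deliberately not derived: a clone would hold another strong
// reference to `tasks`, so dropping one copy would not abort the tasks, and
// the `strong_count == 1` end-of-stream check could not succeed while the
// clone is alive.
#[derive(Debug)]
pub(crate) struct BoundedTreeNursery<T> {
    tasks: Arc<Mutex<JoinSet<T>>>,
}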
impl<T: Send + 'static> BoundedTreeNursery<T> {
    /// Create a `BoundedTreeNursery` that limits the number of active tasks to
    /// at most `limit` and with `root` spawned as the initial task
    pub(crate) fn new<F, Fut>(limit: usize, root: F) -> Self
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let semaphore = Arc::new(Semaphore::new(limit));
        let tasks = Arc::new(Mutex::new(JoinSet::new()));
        let spawner = Spawner {
            semaphore,
            tasks: tasks.clone(),
        };
        spawner.spawn_with_self(root);
        BoundedTreeNursery { tasks }
    }
}
impl<T: 'static> Stream for BoundedTreeNursery<T> {
    type Item = T;

    /// Poll for one of the tasks in the group to complete and return its
    /// return value.
    ///
    /// # Panics
    ///
    /// If a task panics, this method resumes unwinding the panic.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let mut tasks = self.tasks.lock().unwrap_or_else(PoisonError::into_inner);
        match ready!(tasks.poll_join_next(cx)) {
            Some(Ok(r)) => Some(r).into(),
            Some(Err(e)) => match e.try_into_panic() {
                Ok(barf) => std::panic::resume_unwind(barf),
                Err(e) => unreachable!(
                    "Task in BoundedTreeNursery should not have been aborted, but got {e:?}"
                ),
            },
            None => {
                if Arc::strong_count(&self.tasks) == 1 {
                    // All spawners dropped and all results yielded; end of
                    // stream
                    None.into()
                } else {
                    Poll::Pending
                }
            }
        }
    }
}
/// A handle for spawning tasks in a `BoundedTreeNursery<T>`
#[derive(Debug)]
pub(crate) struct Spawner<T> {
    semaphore: Arc<Semaphore>,
    tasks: Arc<Mutex<JoinSet<T>>>,
}

// Clone can't be derived, as that would erroneously add `T: Clone` bounds to
// the impl.
impl<T> Clone for Spawner<T> {
    fn clone(&self) -> Spawner<T> {
        Spawner {
            semaphore: self.semaphore.clone(),
            tasks: self.tasks.clone(),
        }
    }
}
impl<T: Send + 'static> Spawner<T> {
    /// Spawn the given task in the task group, passing it a new `Spawner`
    pub(crate) fn spawn<F, Fut>(&self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        self.clone().spawn_with_self(func);
    }

    /// Spawn the given task in the task group, passing it this `Spawner`
    fn spawn_with_self<F, Fut>(self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let semaphore = self.semaphore.clone();
        let tasks = self.tasks.clone();
        let mut tasks = tasks.lock().unwrap_or_else(PoisonError::into_inner);
        tasks.spawn(async move {
            let Ok(_permit) = semaphore.acquire().await else {
                unreachable!("Semaphore should not be closed");
            };
            func(self).await
        });
    }
}
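For comparison with the blocking `lock()` in `poll_next()` above, here is a rough std-only sketch of one alternative I could imagine: `try_lock()` plus an immediate self-wake when the mutex is contended. The names `poll_locked`, `CountingWaker` and `counting_waker` are made up for this illustration, and the `T` here is just "whatever the mutex protects" rather than a `JoinSet`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, TryLockError};
use std::task::{Context, Poll, Wake, Waker};

/// Run `f` on the locked value if the lock is free; otherwise request
/// another poll and return `Pending` instead of blocking the thread.
fn poll_locked<T, R>(
    mutex: &Mutex<T>,
    cx: &mut Context<'_>,
    f: impl FnOnce(&mut T) -> Poll<R>,
) -> Poll<R> {
    match mutex.try_lock() {
        Ok(mut guard) => f(&mut *guard),
        // Same poison handling as the code above: keep using the data.
        Err(TryLockError::Poisoned(poisoned)) => {
            let mut guard = poisoned.into_inner();
            f(&mut *guard)
        }
        Err(TryLockError::WouldBlock) => {
            // Busy-retry: wake ourselves so the executor polls again soon.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

/// Helper for experimenting: a waker that counts how often it was woken.
struct CountingWaker(AtomicUsize);

impl Wake for CountingWaker {
    fn wake(self: Arc<Self>) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
}

/// Build a `Waker` together with a handle for reading its wake count.
fn counting_waker() -> (Arc<CountingWaker>, Waker) {
    let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
    let waker = Waker::from(Arc::clone(&counter));
    (counter, waker)
}
```

The trade-off is that a contended mutex turns into repeated polling rather than a short block, which may well cost more than simply waiting when the critical section is as brief as a `spawn` call. I have not measured either variant.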
A partially-stubbed example of how this can be used to traverse a remote file
hierarchy:
use futures_util::{future::BoxFuture, FutureExt, TryStreamExt};
use std::fmt;
use url::Url;
async fn traverse(base_url: Url, workers: usize) -> anyhow::Result<()> {
    let client = Client;
    let mut stream = BoundedTreeNursery::new(workers, move |spawner| {
        process_dir(spawner, client, base_url)
    });
    while let Some(r) = stream.try_next().await? {
        println!("{r}");
    }
    Ok(())
}
fn process_dir(
    spawner: Spawner<anyhow::Result<Report>>,
    client: Client,
    url: Url,
) -> BoxFuture<'static, anyhow::Result<Report>> {
    // We need to return a boxed Future in order to be able to call
    // `process_dir()` inside itself.
    async move {
        let dl = client.list_directory(url).await?;
        for d in dl.directories {
            let cl2 = client.clone();
            spawner.spawn(move |spawner| Box::pin(process_dir(spawner, cl2, d)));
        }
        for f in dl.files {
            let cl2 = client.clone();
            spawner.spawn(move |_spawner| process_file(cl2, f));
        }
        Ok(Report)
    }
    .boxed()
}

async fn process_file(client: Client, url: Url) -> anyhow::Result<Report> {
    client.get_file(url).await
}
/// Stub client for traversing a web hierarchy of some sort (e.g., WebDAV)
#[derive(Clone, Debug)]
struct Client;
impl Client {
    // `DirectoryListing` is not generic, so it takes no type parameter here.
    async fn list_directory(&self, url: Url) -> anyhow::Result<DirectoryListing> {
        todo!()
    }

    async fn get_file(&self, url: Url) -> anyhow::Result<Report> {
        todo!()
    }
}

struct DirectoryListing {
    directories: Vec<Url>,
    files: Vec<Url>,
}
/// Information returned for each request
struct Report;
impl fmt::Display for Report {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        todo!()
    }
}
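The comment in `process_dir()` about needing a boxed future can be shown on its own with a small std-only sketch. `Dir`, `count_files`, `NoopWaker` and `block_on` are all hypothetical names for this illustration; `block_on` is a toy executor that only works because these futures never actually wait on anything.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical in-memory tree standing in for a remote directory hierarchy.
struct Dir {
    files: u32,
    subdirs: Vec<Dir>,
}

/// Count files recursively. The future is boxed because an unboxed async
/// function that awaits itself would need a future type containing itself.
fn count_files(dir: &Dir) -> Pin<Box<dyn Future<Output = u32> + '_>> {
    Box::pin(async move {
        let mut total = dir.files;
        for sub in &dir.subdirs {
            total += count_files(sub).await;
        }
        total
    })
}

/// A waker that does nothing; enough for futures that are always ready.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

/// Toy executor: poll in a loop until the future completes.
fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}
```

Unlike `process_dir()`, this version awaits each subdirectory in sequence instead of spawning it into the group, so it shows only the boxing, not the concurrency.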