A task group as lovely as a tree

As part of a larger program, I recently had a need for a task group that:

  • Limits the number of tasks active at once

  • Returns a stream of task return values

  • Aborts all tasks when the stream is dropped

  • Gives each task a way to spawn more tasks in the same group (giving rise to a directed acyclic graph or tree of tasks, hence the naming)

Not-entirely-thorough searching didn't turn up any pre-existing matches on crates.io, so I ended up writing my own. It seems to work fine, but with this much asyncery, I thought I'd get some more eyes on it.

Aside from a need for general code review advice, I'm concerned about the following:

  • Is it a good idea to block on std::sync::Mutex::lock() inside a poll method? The mutex should only ever be locked for very small amounts of time (I think; how fast is JoinSet::spawn?), but it's theoretically possible that a swarm of tasks spawning in quick succession could hog the mutex and keep poll_next() from returning for a while.

  • Is it a good idea to check whether all the Spawners have been dropped by seeing whether the strong count of Arc<Mutex<JoinSet>> is one? I don't believe there would be any false negatives with this approach, but I'm open to being proven wrong.
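To make the second concern concrete, here is a minimal std-only sketch of the bookkeeping that check relies on. `Handle`, `all_handles_dropped`, and `demo` are hypothetical stand-ins for `Spawner<T>` and the test in `poll_next()`, not part of the real code. It only exercises the single-threaded case; the standard library notes that another thread can change the strong count at any time, so this says nothing about races.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for `Spawner<T>`: it only holds a clone of the
/// shared state.
struct Handle {
    _shared: Arc<Mutex<Vec<u32>>>,
}

/// Hypothetical helper mirroring the `Arc::strong_count(..) == 1` check.
fn all_handles_dropped(shared: &Arc<Mutex<Vec<u32>>>) -> bool {
    Arc::strong_count(shared) == 1
}

/// Records the result of the check as handles are created and dropped.
fn demo() -> Vec<bool> {
    let shared = Arc::new(Mutex::new(Vec::new()));
    // No handles exist yet, so the owner holds the only reference.
    let mut seen = vec![all_handles_dropped(&shared)];

    let a = Handle { _shared: Arc::clone(&shared) };
    let b = Handle { _shared: Arc::clone(&shared) };
    seen.push(all_handles_dropped(&shared)); // two handles alive

    drop(a);
    seen.push(all_handles_dropped(&shared)); // one handle alive

    drop(b);
    seen.push(all_handles_dropped(&shared)); // every handle dropped
    seen
}
```

Calling `demo()` from a `main` function shows the check flipping back to `true` only after the last handle is gone.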

use futures_util::Stream;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{ready, Context, Poll};
use tokio::{sync::Semaphore, task::JoinSet};

/// A task group with the following properties:
///
/// - No more than a certain number of tasks are ever active at once.
///
/// - Each task is passed a `Spawner` that can be used to spawn more tasks in
///   the group.
///
/// - `BoundedTreeNursery<T>` is a `Stream` of the return values of the tasks
///   (which must all be `T`).
///
/// - Dropping `BoundedTreeNursery` causes all tasks to be aborted.
#[derive(Clone, Debug)]
pub(crate) struct BoundedTreeNursery<T> {
    tasks: Arc<Mutex<JoinSet<T>>>,
}

impl<T: Send + 'static> BoundedTreeNursery<T> {
    /// Create a `BoundedTreeNursery` that limits the number of active tasks to
    /// at most `limit` and with `root` spawned as the initial task
    pub(crate) fn new<F, Fut>(limit: usize, root: F) -> Self
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let semaphore = Arc::new(Semaphore::new(limit));
        let tasks = Arc::new(Mutex::new(JoinSet::new()));
        let spawner = Spawner {
            semaphore,
            tasks: tasks.clone(),
        };
        spawner.spawn_with_self(root);
        BoundedTreeNursery { tasks }
    }
}

impl<T: 'static> Stream for BoundedTreeNursery<T> {
    type Item = T;

    /// Poll for one of the tasks in the group to complete and return its
    /// return value.
    ///
    /// # Panics
    ///
    /// If a task panics, this method resumes unwinding the panic.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let mut tasks = self.tasks.lock().unwrap_or_else(PoisonError::into_inner);
        match ready!(tasks.poll_join_next(cx)) {
            Some(Ok(r)) => Some(r).into(),
            Some(Err(e)) => match e.try_into_panic() {
                Ok(barf) => std::panic::resume_unwind(barf),
                Err(e) => unreachable!(
                    "Task in BoundedTreeNursery should not have been aborted, but got {e:?}"
                ),
            },
            None => {
                if Arc::strong_count(&self.tasks) == 1 {
                    // All spawners dropped and all results yielded; end of
                    // stream
                    None.into()
                } else {
                    Poll::Pending
                }
            }
        }
    }
}

/// A handle for spawning tasks in a `BoundedTreeNursery<T>`
#[derive(Debug)]
pub(crate) struct Spawner<T> {
    semaphore: Arc<Semaphore>,
    tasks: Arc<Mutex<JoinSet<T>>>,
}

// Clone can't be derived, as that would erroneously add `T: Clone` bounds to
// the impl.
impl<T> Clone for Spawner<T> {
    fn clone(&self) -> Spawner<T> {
        Spawner {
            semaphore: self.semaphore.clone(),
            tasks: self.tasks.clone(),
        }
    }
}

impl<T: Send + 'static> Spawner<T> {
    /// Spawn the given task in the task group, passing it a new `Spawner`
    pub(crate) fn spawn<F, Fut>(&self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        self.clone().spawn_with_self(func);
    }

    /// Spawn the given task in the task group, passing it this `Spawner`
    fn spawn_with_self<F, Fut>(self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let semaphore = self.semaphore.clone();
        let tasks = self.tasks.clone();
        let mut tasks = tasks.lock().unwrap_or_else(PoisonError::into_inner);
        tasks.spawn(async move {
            let Ok(_permit) = semaphore.acquire().await else {
                unreachable!("Semaphore should not be closed");
            };
            func(self).await
        });
    }
}

A partially-stubbed example of how this can be used to traverse a remote file
hierarchy:

use futures_util::{future::BoxFuture, FutureExt, TryStreamExt};
use std::fmt;
use url::Url;

async fn traverse(base_url: Url, workers: usize) -> anyhow::Result<()> {
    let client = Client;
    let mut stream = BoundedTreeNursery::new(workers, move |spawner| {
        process_dir(spawner, client, base_url)
    });
    while let Some(r) = stream.try_next().await? {
        println!("{r}");
    }
    Ok(())
}

fn process_dir(
    spawner: Spawner<anyhow::Result<Report>>,
    client: Client,
    url: Url,
) -> BoxFuture<'static, anyhow::Result<Report>> {
    // We need to return a boxed Future in order to be able to call
    // `process_dir()` inside itself.
    async move {
        let dl = client.list_directory(url).await?;
        for d in dl.directories {
            let cl2 = client.clone();
            // `process_dir()` already returns a boxed future, so no extra
            // `Box::pin()` is needed here.
            spawner.spawn(move |spawner| process_dir(spawner, cl2, d));
        }
        for f in dl.files {
            let cl2 = client.clone();
            spawner.spawn(move |_spawner| process_file(cl2, f));
        }
        Ok(Report)
    }
    .boxed()
}

async fn process_file(client: Client, url: Url) -> anyhow::Result<Report> {
    client.get_file(url).await
}

/// Stub client for traversing a web hierarchy of some sort (e.g., WebDAV)
#[derive(Clone, Debug)]
struct Client;

impl Client {
    async fn list_directory(&self, url: Url) -> anyhow::Result<DirectoryListing> {
        todo!()
    }

    async fn get_file(&self, url: Url) -> anyhow::Result<Report> {
        todo!()
    }
}

struct DirectoryListing {
    directories: Vec<Url>,
    files: Vec<Url>,
}

/// Information returned for each request
struct Report;

impl fmt::Display for Report {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        todo!()
    }
}

This should be doable by feeding a channel of "to be spawned" futures into StreamExt::buffer_unordered(). It might need some helpers, but at least you'd be able to avoid doing your own counting and signalling.

If by "to be spawned" futures you mean plain Future objects (i.e., tokio::spawn() isn't actually involved), I believe that would have the same footgun as FuturesUnordered (ref1, ref2): active futures aren't polled while the body of the consuming while loop is running, so if the body runs for too long, network actions could time out, among other possible issues. I actually got bitten by this once, so now I avoid FuturesUnordered and buffer_unordered() unless I can be very, very sure this problem won't crop up.

I wasn't sure whether you needed actual spawning, or merely concurrency. If you do, that's easy to fix, if somewhat clunky:

stream
    .map(|fut| async move { tokio::spawn(fut).await })
    .buffer_unordered(n)

Now the spawns happen in buffered fashion.

But with that snippet, dropping the stream no longer aborts the active tasks. I don't see a quick fix for that, do you?

Sorry, I was too hasty to respond and left out some further details about the “some helpers” I mentioned in my original post. You can achieve that, and have something useful in other contexts too, by wrapping a JoinHandle so as to call abort() if it is dropped, turning the auto-detaching spawned task into a structured-concurrency task.

(Perhaps this is too many parts. But I like code that achieves its goal by composing small, single-purpose, robust components that are easy to comprehend in isolation.)
