A task group as lovely as a tree, take 2

Last week, I asked for a code review of a "bounded tree nursery" type I had written and got some pointers to an alternative approach in response. I've followed the general advice given there, with some modifications (see below).

As before, the goal of the code is to provide a task group that:

  • Limits the number of tasks active at once

  • Returns a stream of task return values

  • Aborts all tasks when the stream is dropped

  • Gives each task a way to spawn more tasks in the same group (giving rise to a tree of tasks, hence the naming)

@kpreid's suggestions involved creating a stream of JoinHandles returned from tokio::spawn() and batch-awaiting them using StreamExt::buffer_unordered(). However, with this approach buffer_unordered() limits only the number of handles polled at once, not the number of active tasks, since each task starts running as soon as it is spawned. I thus had to reintroduce the original code's approach of having each task acquire a semaphore permit before running, at which point buffer_unordered() did more harm than good (because it doesn't poll all of the spawned handles).
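Why buffer_unordered() can't act as a limiter can be shown without tokio: tokio::spawn() starts a task right away, and awaiting its handle only collects the result. OS threads behave the same way, so the following is a dependency-free analogy using std::thread in place of tokio tasks (spawn_then_join_in_batches is a hypothetical helper, not part of the code above). Every thread waits on a Barrier sized to the total thread count, so the function can only return if all threads are alive at once, even though the handles are joined just two at a time.

```rust
use std::sync::{Arc, Barrier};
use std::thread;

/// Spawn `n` threads that all rendezvous at one barrier, then join the
/// handles `batch` at a time. If joining in batches limited how many threads
/// ran concurrently, the barrier would never release and this would hang.
fn spawn_then_join_in_batches(n: usize, batch: usize) -> Vec<usize> {
    assert!(batch > 0, "batch must be positive");
    let barrier = Arc::new(Barrier::new(n));
    let mut handles: Vec<thread::JoinHandle<usize>> = (0..n)
        .map(|i| {
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                // Blocks until all `n` threads have reached this point.
                barrier.wait();
                i * 10
            })
        })
        .collect();

    let mut results = Vec::new();
    while !handles.is_empty() {
        // Only `batch` handles are waited on at a time...
        let take = batch.min(handles.len());
        for handle in handles.drain(..take) {
            results.push(handle.join().expect("thread panicked"));
        }
    }
    // ...yet all `n` threads had to be running simultaneously to get here.
    results
}
```

The same reasoning applies to the async version: the only thing that bounds how many task bodies run at once is something the tasks themselves wait on, such as the semaphore.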

My solution was to replace buffer_unordered() with a FuturesUnordered. At first, I wrapped this in an Arc that was cloned among Spawners, but polling the FuturesUnordered requires a mutable reference, which meant I'd have to add a mutex, something I'm trying to get away from in this version of the code. Instead, I ended up giving each Spawner an UnboundedSender (unbounded so that Spawner::spawn() could be synchronous) for sending JoinHandles (in an abort-on-drop wrapper), with the BoundedTreeNursery transferring all received handles to the FuturesUnordered on each poll.
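The handle-forwarding arrangement can be sketched with OS threads and std::sync::mpsc. This is an analogy under assumptions, not the tokio code: ThreadSpawner and ThreadNursery are hypothetical names, there is no semaphore or abort-on-drop, and results come back in spawn order rather than completion order. Sending on an unbounded std channel never blocks, which is what lets spawn() stay synchronous; the nursery moves whatever handles have arrived into its own queue each time it is asked for the next result.

```rust
use std::collections::VecDeque;
use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
use std::thread::{self, JoinHandle};

/// Hypothetical stand-in for `Spawner`: all it owns is a sender of handles.
pub struct ThreadSpawner<T> {
    sender: Sender<JoinHandle<T>>,
}

// Manual impl to avoid the `T: Clone` bound a derive would add.
impl<T> Clone for ThreadSpawner<T> {
    fn clone(&self) -> Self {
        ThreadSpawner { sender: self.sender.clone() }
    }
}

impl<T: Send + 'static> ThreadSpawner<T> {
    /// Synchronous: sending on an unbounded channel never blocks.
    pub fn spawn<F>(&self, func: F)
    where
        F: FnOnce(ThreadSpawner<T>) -> T + Send + 'static,
    {
        let child = self.clone();
        let handle = thread::spawn(move || func(child));
        // If the nursery is gone, the handle is dropped and the thread detaches.
        let _ = self.sender.send(handle);
    }
}

/// Hypothetical stand-in for `BoundedTreeNursery` (no limit, no abort).
pub struct ThreadNursery<T> {
    receiver: Receiver<JoinHandle<T>>,
    tasks: VecDeque<JoinHandle<T>>,
    closed: bool,
}

impl<T: Send + 'static> ThreadNursery<T> {
    pub fn new<F>(root: F) -> Self
    where
        F: FnOnce(ThreadSpawner<T>) -> T + Send + 'static,
    {
        let (sender, receiver) = channel();
        ThreadSpawner { sender }.spawn(root);
        ThreadNursery { receiver, tasks: VecDeque::new(), closed: false }
    }
}

impl<T> ThreadNursery<T> {
    /// Move every handle that has arrived so far into `tasks`, draining
    /// until the channel is empty rather than taking one fixed-size batch.
    fn drain_channel(&mut self) {
        loop {
            match self.receiver.try_recv() {
                Ok(handle) => self.tasks.push_back(handle),
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Disconnected) => {
                    self.closed = true;
                    break;
                }
            }
        }
    }
}

// Blocking analogue of `poll_next`.
impl<T> Iterator for ThreadNursery<T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        loop {
            self.drain_channel();
            if let Some(handle) = self.tasks.pop_front() {
                return Some(handle.join().expect("task panicked"));
            }
            if self.closed {
                return None; // all spawners dropped and all results yielded
            }
            // Nothing queued yet, but spawners are alive: block for a handle.
            match self.receiver.recv() {
                Ok(handle) => self.tasks.push_back(handle),
                Err(_) => self.closed = true,
            }
        }
    }
}
```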

Aside from general code review advice, I'm particularly concerned about polling the receiver alongside the FuturesUnordered. Is this the right thing to do? In particular, I had no idea what value to use for the poll_recv_many() limit, so I just picked 32 out of thin air. Would it be better for the limit to be the same value as (or a function of) the limit passed to BoundedTreeNursery::new()?

use futures_util::{stream::FuturesUnordered, Stream, StreamExt};
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use tokio::{
    sync::{
        mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
        Semaphore,
    },
    task::{JoinError, JoinHandle},
};

/// A task group with the following properties:
///
/// - No more than a certain number of tasks are ever active at once.
///
/// - Each task is passed a `Spawner` that can be used to spawn more tasks in
///   the group.
///
/// - `BoundedTreeNursery<T>` is a `Stream` of the return values of the tasks
///   (which must all be `T`).
///
/// - Dropping `BoundedTreeNursery` causes all tasks to be aborted.
#[derive(Debug)]
pub(crate) struct BoundedTreeNursery<T> {
    receiver: UnboundedReceiver<FragileHandle<T>>,
    tasks: FuturesUnordered<FragileHandle<T>>,
    closed: bool,
}

impl<T: Send + 'static> BoundedTreeNursery<T> {
    /// Create a `BoundedTreeNursery` that limits the number of active tasks to
    /// at most `limit` and with `root` spawned as the initial task
    pub(crate) fn new<F, Fut>(limit: usize, root: F) -> Self
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let semaphore = Arc::new(Semaphore::new(limit));
        let (sender, receiver) = unbounded_channel();
        let spawner = Spawner { semaphore, sender };
        spawner.spawn_with_self(root);
        BoundedTreeNursery {
            tasks: FuturesUnordered::new(),
            receiver,
            closed: false,
        }
    }
}

impl<T: Send + 'static> Stream for BoundedTreeNursery<T> {
    type Item = T;

    /// Poll for one of the tasks in the group to complete and return its
    /// return value.
    ///
    /// # Panics
    ///
    /// If a task panics, this method resumes unwinding the panic.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let mut buf = Vec::new();
        match self.receiver.poll_recv_many(cx, &mut buf, 32) {
            Poll::Pending => (),
            Poll::Ready(0) => self.closed = true,
            Poll::Ready(_) => self.tasks.extend(buf),
        }
        match ready!(self.tasks.poll_next_unpin(cx)) {
            Some(Ok(r)) => Some(r).into(),
            Some(Err(e)) => match e.try_into_panic() {
                Ok(barf) => std::panic::resume_unwind(barf),
                Err(e) => unreachable!(
                    "Task in BoundedTreeNursery should not have been aborted, but got {e:?}"
                ),
            },
            None => {
                if self.closed {
                    // All spawners dropped and all results yielded; end of
                    // stream
                    None.into()
                } else {
                    Poll::Pending
                }
            }
        }
    }
}

/// A handle for spawning tasks in a `BoundedTreeNursery<T>`
#[derive(Debug)]
pub(crate) struct Spawner<T> {
    semaphore: Arc<Semaphore>,
    sender: UnboundedSender<FragileHandle<T>>,
}

// Clone can't be derived, as that would erroneously add `T: Clone` bounds to
// the impl.
impl<T> Clone for Spawner<T> {
    fn clone(&self) -> Spawner<T> {
        Spawner {
            semaphore: self.semaphore.clone(),
            sender: self.sender.clone(),
        }
    }
}

impl<T: Send + 'static> Spawner<T> {
    /// Spawn the given task in the task group, passing it a new `Spawner`
    pub(crate) fn spawn<F, Fut>(&self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        self.clone().spawn_with_self(func);
    }

    /// Spawn the given task in the task group, passing it this `Spawner`
    fn spawn_with_self<F, Fut>(self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let semaphore = self.semaphore.clone();
        let sender = self.sender.clone();
        let _ = sender.send(FragileHandle::new(tokio::spawn(async move {
            let Ok(_permit) = semaphore.acquire().await else {
                unreachable!("Semaphore should not be closed");
            };
            func(self).await
        })));
    }
}

pin_project! {
    /// A wrapper around `tokio::task::JoinHandle` that aborts the task on drop.
    #[derive(Debug)]
    struct FragileHandle<T> {
        #[pin]
        inner: JoinHandle<T>
    }

    impl<T> PinnedDrop for FragileHandle<T> {
        fn drop(this: Pin<&mut Self>) {
            this.project().inner.abort();
        }
    }
}

impl<T> FragileHandle<T> {
    fn new(inner: JoinHandle<T>) -> Self {
        FragileHandle { inner }
    }
}

impl<T> Future for FragileHandle<T> {
    type Output = Result<T, JoinError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        this.inner.poll(cx)
    }
}

A partially-stubbed example of how this can be used to traverse a remote file hierarchy:

use futures_util::{future::BoxFuture, FutureExt, TryStreamExt};
use std::fmt;
use url::Url;

async fn traverse(base_url: Url, workers: usize) -> anyhow::Result<()> {
    let client = Client;
    let mut stream = BoundedTreeNursery::new(workers, move |spawner| {
        process_dir(spawner, client, base_url)
    });
    while let Some(r) = stream.try_next().await? {
        println!("{r}");
    }
    Ok(())
}

fn process_dir(
    spawner: Spawner<anyhow::Result<Report>>,
    client: Client,
    url: Url,
) -> BoxFuture<'static, anyhow::Result<Report>> {
    // We need to return a boxed Future in order to be able to call
    // `process_dir()` inside itself.
    async move {
        let dl = client.list_directory(url).await?;
        for d in dl.directories {
            let cl2 = client.clone();
            // `process_dir()` already returns a boxed future, so no extra `Box::pin`.
            spawner.spawn(move |spawner| process_dir(spawner, cl2, d));
        }
        for f in dl.files {
            let cl2 = client.clone();
            spawner.spawn(move |_spawner| process_file(cl2, f));
        }
        Ok(Report)
    }
    .boxed()
}

async fn process_file(client: Client, url: Url) -> anyhow::Result<Report> {
    client.get_file(url).await
}

/// Stub client for traversing a web hierarchy of some sort (e.g., WebDAV)
#[derive(Clone, Debug)]
struct Client;

impl Client {
    async fn list_directory(&self, url: Url) -> anyhow::Result<DirectoryListing> {
        todo!()
    }

    async fn get_file(&self, url: Url) -> anyhow::Result<Report> {
        todo!()
    }
}

struct DirectoryListing {
    directories: Vec<Url>,
    files: Vec<Url>,
}

/// Information returned for each request
struct Report;

impl fmt::Display for Report {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        todo!()
    }
}

I would like to suggest an alternate design:

  • Instead of using an mpsc channel to send join handles, use it to send the return values of the tasks. This way you don't need the FuturesUnordered at all: you don't have to store the JoinHandles, and you can just drop them. (You can still catch panics with catch_unwind.)
  • For cancellation, keep a CancellationToken drop guard in the nursery and give each Spawner a clone of the token for your tasks to select on.
  • Still use a semaphore for limiting tasks.
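With only the standard library, the first and third suggestions can be sketched using threads in place of tokio tasks. This is a sketch under assumptions: ResultNursery, ResultSpawner, and the small Mutex/Condvar semaphore are hypothetical stand-ins for the tokio types, and cancellation is left out. It shows that once results travel over the channel, the receiving side needs no collection of handles, and the semaphore alone bounds how many task bodies run at once.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

/// Minimal counting semaphore standing in for `tokio::sync::Semaphore`.
/// It also records the peak number of permits in use.
struct Semaphore {
    counts: Mutex<Counts>,
    cond: Condvar,
    limit: usize,
}

struct Counts {
    in_use: usize,
    peak: usize,
}

/// RAII permit: releases its slot when dropped.
struct Permit<'a>(&'a Semaphore);

impl Semaphore {
    fn new(limit: usize) -> Self {
        assert!(limit > 0, "limit must be positive");
        Semaphore { counts: Mutex::new(Counts { in_use: 0, peak: 0 }), cond: Condvar::new(), limit }
    }

    fn acquire(&self) -> Permit<'_> {
        let mut counts = self.counts.lock().unwrap();
        while counts.in_use == self.limit {
            counts = self.cond.wait(counts).unwrap();
        }
        counts.in_use += 1;
        counts.peak = counts.peak.max(counts.in_use);
        Permit(self)
    }
}

impl Drop for Permit<'_> {
    fn drop(&mut self) {
        self.0.counts.lock().unwrap().in_use -= 1;
        self.0.cond.notify_one();
    }
}

/// Hypothetical thread-based `Spawner`: a semaphore plus a sender of *results*.
pub struct ResultSpawner<T> {
    semaphore: Arc<Semaphore>,
    sender: Sender<T>,
}

impl<T> Clone for ResultSpawner<T> {
    fn clone(&self) -> Self {
        ResultSpawner { semaphore: Arc::clone(&self.semaphore), sender: self.sender.clone() }
    }
}

impl<T: Send + 'static> ResultSpawner<T> {
    pub fn spawn<F>(&self, func: F)
    where
        F: FnOnce(ResultSpawner<T>) -> T + Send + 'static,
    {
        let child = self.clone();
        // The JoinHandle is dropped right away; the result travels over the channel.
        thread::spawn(move || {
            let value = {
                let _permit = child.semaphore.acquire(); // wait for a free slot
                func(child.clone())
            }; // permit released here, before sending
            let _ = child.sender.send(value);
        });
    }
}

/// Hypothetical nursery: essentially the receiving end of the results channel.
pub struct ResultNursery<T> {
    receiver: Receiver<T>,
    semaphore: Arc<Semaphore>,
}

impl<T: Send + 'static> ResultNursery<T> {
    pub fn new<F>(limit: usize, root: F) -> Self
    where
        F: FnOnce(ResultSpawner<T>) -> T + Send + 'static,
    {
        let semaphore = Arc::new(Semaphore::new(limit));
        let (sender, receiver) = channel();
        ResultSpawner { semaphore: Arc::clone(&semaphore), sender }.spawn(root);
        ResultNursery { receiver, semaphore }
    }

    /// Most task bodies that ever held a permit at the same time.
    pub fn peak_concurrency(&self) -> usize {
        self.semaphore.counts.lock().unwrap().peak
    }
}

impl<T> Iterator for ResultNursery<T> {
    type Item = T;
    /// Blocks for the next result; `None` once every spawner has been dropped.
    fn next(&mut self) -> Option<T> {
        self.receiver.recv().ok()
    }
}
```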

Your poll_next function is wrong. When poll_recv_many returns Ready(n) with n > 0, it doesn't register your waker, so you should call it again in a loop until it returns Ready(0) or Pending. Otherwise you won't be woken when a new message arrives on the channel.

Note that using poll_recv would also work fine here.
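The missed-wakeup problem can be reproduced with a hypothetical ToyChannel built on std::task. It is not tokio's channel, but it follows the same convention: a waker is stored only when poll_recv_many() returns Pending. After a single call that returns Ready(n), a message sent later triggers no wakeup; looping until Pending does register the waker. The batch limit only changes how many iterations that loop takes, which bears on the earlier question about the value 32.

```rust
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical toy channel (not tokio's). Its `poll_recv_many` stores the
/// caller's waker only when it returns `Pending`. It never closes, so it
/// does not produce `Ready(0)` for `limit > 0`.
#[derive(Default)]
struct ToyChannel {
    inner: Mutex<(VecDeque<u32>, Option<Waker>)>,
}

impl ToyChannel {
    fn send(&self, value: u32) {
        let waker = {
            let mut inner = self.inner.lock().unwrap();
            inner.0.push_back(value);
            inner.1.take()
        };
        if let Some(w) = waker {
            w.wake();
        }
    }

    fn poll_recv_many(&self, cx: &mut Context<'_>, buf: &mut Vec<u32>, limit: usize) -> Poll<usize> {
        let mut inner = self.inner.lock().unwrap();
        if inner.0.is_empty() {
            inner.1 = Some(cx.waker().clone());
            return Poll::Pending;
        }
        let n = limit.min(inner.0.len());
        buf.extend(inner.0.drain(..n));
        Poll::Ready(n) // no waker is stored on this path
    }
}

/// Waker that just counts how often it is woken.
struct CountingWaker(AtomicUsize);

impl Wake for CountingWaker {
    fn wake(self: Arc<Self>) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
}

/// Queue five messages, poll once or loop until `Pending`, then send one
/// more message and report how many wakeups it triggered.
fn wakeups_after_late_send(drain_until_pending: bool, limit: usize) -> usize {
    let ch = ToyChannel::default();
    for v in 0..5 {
        ch.send(v);
    }
    let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
    let waker = Waker::from(Arc::clone(&counter));
    let mut cx = Context::from_waker(&waker);
    let mut buf = Vec::new();

    if drain_until_pending {
        loop {
            match ch.poll_recv_many(&mut cx, &mut buf, limit) {
                Poll::Ready(0) => break, // with a real channel: closed
                Poll::Ready(_) => continue,
                Poll::Pending => break, // waker is now registered
            }
        }
    } else {
        // A single call, as in the original `poll_next`.
        let _ = ch.poll_recv_many(&mut cx, &mut buf, limit);
    }

    ch.send(99); // a message arriving after polling returned
    counter.0.load(Ordering::SeqCst)
}
```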

So something like this?

use futures_util::Stream;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{
    mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
    Semaphore,
};
use tokio_util::sync::{CancellationToken, DropGuard};

/// A task group with the following properties:
///
/// - No more than a certain number of tasks are ever active at once.
///
/// - Each task is passed a `Spawner` that can be used to spawn more tasks in
///   the group.
///
/// - `BoundedTreeNursery<T>` is a `Stream` of the return values of the tasks
///   (which must all be `T`).
///
/// - Dropping `BoundedTreeNursery` causes all tasks to be aborted.
#[derive(Debug)]
pub(crate) struct BoundedTreeNursery<T> {
    receiver: UnboundedReceiver<T>,
    _on_drop: DropGuard,
}

impl<T: Send + 'static> BoundedTreeNursery<T> {
    /// Create a `BoundedTreeNursery` that limits the number of active tasks to
    /// at most `limit` and with `root` spawned as the initial task
    pub(crate) fn new<F, Fut>(limit: usize, root: F) -> Self
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let semaphore = Arc::new(Semaphore::new(limit));
        let token = CancellationToken::new();
        let on_drop = token.clone().drop_guard();
        let (sender, receiver) = unbounded_channel();
        let spawner = Spawner {
            semaphore,
            sender,
            token,
        };
        spawner.spawn_with_self(root);
        BoundedTreeNursery {
            receiver,
            _on_drop: on_drop,
        }
    }
}

impl<T: 'static> Stream for BoundedTreeNursery<T> {
    type Item = T;

    /// Poll for one of the tasks in the group to complete and return its
    /// return value.
    ///
    /// # Panics
    ///
    /// If a task panics, this method resumes unwinding the panic.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        self.receiver.poll_recv(cx)
    }
}

/// A handle for spawning tasks in a `BoundedTreeNursery<T>`
#[derive(Debug)]
pub(crate) struct Spawner<T> {
    semaphore: Arc<Semaphore>,
    sender: UnboundedSender<T>,
    token: CancellationToken,
}

// Clone can't be derived, as that would erroneously add `T: Clone` bounds to
// the impl.
impl<T> Clone for Spawner<T> {
    fn clone(&self) -> Spawner<T> {
        Spawner {
            semaphore: self.semaphore.clone(),
            sender: self.sender.clone(),
            token: self.token.clone(),
        }
    }
}

impl<T: Send + 'static> Spawner<T> {
    /// Spawn the given task in the task group, passing it a new `Spawner`
    pub(crate) fn spawn<F, Fut>(&self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        self.clone().spawn_with_self(func);
    }

    /// Spawn the given task in the task group, passing it this `Spawner`
    fn spawn_with_self<F, Fut>(self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let Spawner {
            semaphore,
            sender,
            token,
        } = self.clone();
        let fut = async move {
            let Ok(_permit) = semaphore.acquire().await else {
                unreachable!("Semaphore should not be closed");
            };
            func(self).await
        };
        tokio::spawn(async move {
            tokio::select!(
                () = token.cancelled() => (),
                r = fut => {
                    let _ = sender.send(r);
                }
            );
        });
    }
}
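In this version, cancellation hinges on the nursery owning a DropGuard: dropping the nursery cancels the token, and each task's select! can then finish on the token.cancelled() branch, which drops fut. The RAII half of that can be mimicked with the standard library. CancelToken and CancelOnDrop below are hypothetical stand-ins for tokio_util's CancellationToken and DropGuard, and because a thread can't be interrupted from outside, each worker checks the flag itself instead of selecting on it.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

/// Hypothetical stand-in for `CancellationToken`: a shared flag.
#[derive(Clone, Default)]
pub struct CancelToken(Arc<AtomicBool>);

impl CancelToken {
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    pub fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }

    /// Like `drop_guard()` on the real token: cancel when the guard drops.
    pub fn drop_guard(self) -> CancelOnDrop {
        CancelOnDrop(self)
    }
}

/// Hypothetical stand-in for `DropGuard`; the nursery would own this.
pub struct CancelOnDrop(CancelToken);

impl Drop for CancelOnDrop {
    fn drop(&mut self) {
        self.0.cancel();
    }
}

/// Spawn `n` workers that run until cancelled, drop the guard, and return
/// how many workers observed the cancellation and exited.
fn cancel_by_dropping_guard(n: usize) -> usize {
    let token = CancelToken::default();
    let guard = token.clone().drop_guard();
    let workers: Vec<_> = (0..n)
        .map(|_| {
            let token = token.clone();
            thread::spawn(move || {
                // Cooperative check; the async version uses `select!` instead.
                while !token.is_cancelled() {
                    thread::sleep(Duration::from_millis(1));
                }
                true
            })
        })
        .collect();
    drop(guard); // plays the role of dropping `BoundedTreeNursery`
    workers.into_iter().map(|w| w.join().unwrap()).filter(|&exited| exited).count()
}
```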

Pretty much. My only question there would be whether you want to handle panics in fut.

Right, I forgot about panics. I adjusted the last code as follows to handle them:

diff --git a/src/btn.rs b/src/btn.rs
index 4a9564e..95b61da 100644
--- a/src/btn.rs
+++ b/src/btn.rs
@@ -1,14 +1,16 @@
-use futures_util::Stream;
+use futures_util::{FutureExt, Stream};
 use std::future::Future;
 use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
+use std::task::{ready, Context, Poll};
 use tokio::sync::{
     mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
     Semaphore,
 };
 use tokio_util::sync::{CancellationToken, DropGuard};
 
+type UnwindResult<T> = Result<T, Box<dyn std::any::Any + Send>>;
+
 /// A task group with the following properties:
 ///
 /// - No more than a certain number of tasks are ever active at once.
@@ -22,7 +24,7 @@ use tokio_util::sync::{CancellationToken, DropGuard};
 /// - Dropping `BoundedTreeNursery` causes all tasks to be aborted.
 #[derive(Debug)]
 pub(crate) struct BoundedTreeNursery<T> {
-    receiver: UnboundedReceiver<T>,
+    receiver: UnboundedReceiver<UnwindResult<T>>,
     _on_drop: DropGuard,
 }
 
@@ -61,7 +63,11 @@ impl<T: 'static> Stream for BoundedTreeNursery<T> {
     ///
     /// If a task panics, this method resumes unwinding the panic.
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
-        self.receiver.poll_recv(cx)
+        match ready!(self.receiver.poll_recv(cx)) {
+            Some(Ok(r)) => Some(r).into(),
+            Some(Err(e)) => std::panic::resume_unwind(e),
+            None => None.into(),
+        }
     }
 }
 
@@ -69,7 +75,7 @@ impl<T: 'static> Stream for BoundedTreeNursery<T> {
 #[derive(Debug)]
 pub(crate) struct Spawner<T> {
     semaphore: Arc<Semaphore>,
-    sender: UnboundedSender<T>,
+    sender: UnboundedSender<UnwindResult<T>>,
     token: CancellationToken,
 }
 
@@ -115,7 +121,7 @@ impl<T: Send + 'static> Spawner<T> {
         tokio::spawn(async move {
             tokio::select!(
                 () = token.cancelled() => (),
-                r = fut => {
+                r = std::panic::AssertUnwindSafe(fut).catch_unwind() => {
                     let _ = sender.send(r);
                 }
             );

I have concerns about the correctness of using AssertUnwindSafe here, but they just echo this thread from two years ago.
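The pattern in the diff (catch the panic where the task runs, send the payload through the channel, and resume_unwind() it on the consumer's side) can be tried in isolation with threads and std::sync::mpsc. run_all and next_value are hypothetical helpers for illustration, and the AssertUnwindSafe wrapper carries the same caveat as in the diff: it is a promise that nothing observes state a panicking job may have left inconsistent.

```rust
use std::any::Any;
use std::panic::{self, AssertUnwindSafe};
use std::sync::mpsc::{channel, Receiver};
use std::thread;

/// Same shape as the diff's `UnwindResult<T>`.
type UnwindResult<T> = Result<T, Box<dyn Any + Send>>;

/// Run each job on its own thread, forwarding either its value or its
/// panic payload over the channel.
fn run_all<T, F>(jobs: Vec<F>) -> Receiver<UnwindResult<T>>
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    let (sender, receiver) = channel();
    for job in jobs {
        let sender = sender.clone();
        thread::spawn(move || {
            // Same judgement call as the diff: assert the job is unwind-safe.
            let _ = sender.send(panic::catch_unwind(AssertUnwindSafe(job)));
        });
    }
    receiver // the original `sender` is dropped here, so the channel closes later
}

/// Consumer side, mirroring `poll_next`: values are returned, panics re-raised.
fn next_value<T>(receiver: &Receiver<UnwindResult<T>>) -> Option<T> {
    match receiver.recv().ok()? {
        Ok(value) => Some(value),
        Err(payload) => panic::resume_unwind(payload),
    }
}
```

Note that the default panic hook still prints the worker's panic message when the job panics; resume_unwind() itself does not invoke the hook again.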