Last week, I asked for a code review of a "bounded tree nursery" type I had written and got some pointers to an alternative approach in response. I've followed the general advice given there, with some modifications (see below).
As before, the goal of the code is to provide a task group that:
- Limits the number of tasks active at once
- Returns a stream of task return values
- Aborts all tasks when the stream is dropped
- Gives each task a way to spawn more tasks in the same group (giving rise to a tree of tasks, hence the naming)
@kpreid's suggestions involved creating a stream of `JoinHandle`s returned from `tokio::spawn()` and batch-awaiting them using `StreamExt::buffer_unordered()`. However, under this approach `buffer_unordered()` has no limiting effect on the number of active tasks, only on the number of handles polled at once, because `tokio::spawn()` starts each task running immediately whether or not anything is polling its handle. I thus had to reintroduce the original code's approach of making each task acquire a semaphore permit before doing its work, at which point `buffer_unordered()` did more harm than good, since it left some of the spawned handles unpolled.
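The point about `buffer_unordered()` can be illustrated with a small std-only sketch. This is not the nursery's code. OS threads stand in for tokio tasks, and joining the handles one at a time stands in for awaiting `JoinHandle`s through a buffer of size 1. `Semaphore` is a hypothetical hand-rolled counting semaphore, not tokio's. Every worker starts as soon as it is spawned, so only acquiring a permit inside the worker bounds how many run at once.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier, Condvar, Mutex};
use std::thread;
use std::time::Duration;

/// Hypothetical minimal counting semaphore (stand-in for tokio's `Semaphore`).
struct Semaphore {
    permits: Mutex<usize>,
    available: Condvar,
}

impl Semaphore {
    fn new(permits: usize) -> Self {
        Semaphore { permits: Mutex::new(permits), available: Condvar::new() }
    }

    /// Blocks until a permit is free, then takes it.
    fn acquire(&self) {
        let mut permits = self.permits.lock().unwrap();
        while *permits == 0 {
            permits = self.available.wait(permits).unwrap();
        }
        *permits -= 1;
    }

    /// Returns a permit and wakes one waiting worker.
    fn release(&self) {
        *self.permits.lock().unwrap() += 1;
        self.available.notify_one();
    }
}

/// Spawns `n` workers up front with no semaphore and joins them one at a
/// time. Each worker waits at a barrier while counted as active, so all `n`
/// are provably running at once despite the serial joining.
fn peak_without_limit(n: usize) -> usize {
    let barrier = Arc::new(Barrier::new(n));
    let active = Arc::new(AtomicUsize::new(0));
    let peak = Arc::new(AtomicUsize::new(0));
    let handles: Vec<_> = (0..n)
        .map(|_| {
            let (barrier, active, peak) = (barrier.clone(), active.clone(), peak.clone());
            thread::spawn(move || {
                let now = active.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now, Ordering::SeqCst);
                barrier.wait();
                active.fetch_sub(1, Ordering::SeqCst);
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    peak.load(Ordering::SeqCst)
}

/// Same shape, but each worker holds a permit while it is active, which caps
/// concurrency at `limit` no matter how many workers were spawned.
fn peak_with_limit(n: usize, limit: usize) -> usize {
    let sem = Arc::new(Semaphore::new(limit));
    let active = Arc::new(AtomicUsize::new(0));
    let peak = Arc::new(AtomicUsize::new(0));
    let handles: Vec<_> = (0..n)
        .map(|_| {
            let (sem, active, peak) = (sem.clone(), active.clone(), peak.clone());
            thread::spawn(move || {
                sem.acquire();
                let now = active.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now, Ordering::SeqCst);
                thread::sleep(Duration::from_millis(5));
                active.fetch_sub(1, Ordering::SeqCst);
                sem.release();
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    peak.load(Ordering::SeqCst)
}
```

`peak_without_limit(4)` always reports 4. `peak_with_limit(8, 2)` never reports more than 2.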
My solution was to replace `buffer_unordered()` with a `FuturesUnordered`. At first, I wrapped this in an `Arc` that was cloned among `Spawner`s, but polling the `FuturesUnordered` requires a mutable reference, which meant I'd have to add a mutex, something I'm trying to get away from in this version of the code. Instead, I ended up giving each `Spawner` an `UnboundedSender` (unbounded so that `Spawner::spawn()` could be sync) for sending `JoinHandle`s (in an abort-on-drop wrapper), with the `BoundedTreeNursery` transferring the handles it receives to the `FuturesUnordered` on each poll.
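The reason an unbounded channel keeps `spawn()` synchronous can be seen with std's channels, used here as a stand-in for tokio's. A bounded `sync_channel` makes the sender wait once its buffer is full (or, with `try_send()`, fail). `channel()` always accepts a message immediately. tokio's bounded `Sender::send()` is likewise `async` because it may have to wait for capacity, while `UnboundedSender::send()` is a plain function. The helper names below are hypothetical.

```rust
use std::sync::mpsc::{channel, sync_channel, TrySendError};

/// Counts how many of `n` messages a bounded std channel of capacity `cap`
/// accepts before the sender would have to wait.
fn bounded_accepts(cap: usize, n: usize) -> usize {
    // `_rx` (not `_`) keeps the receiver alive, so the channel stays open.
    let (tx, _rx) = sync_channel::<usize>(cap);
    let mut accepted = 0;
    for i in 0..n {
        match tx.try_send(i) {
            Ok(()) => accepted += 1,
            // A blocking `send()` would wait here for the receiver to make room.
            Err(TrySendError::Full(_)) => break,
            Err(TrySendError::Disconnected(_)) => unreachable!("receiver is alive"),
        }
    }
    accepted
}

/// The unbounded std channel never makes the sender wait.
fn unbounded_accepts(n: usize) -> usize {
    let (tx, _rx) = channel::<usize>();
    (0..n).filter(|&i| tx.send(i).is_ok()).count()
}
```

With a capacity of 4, `bounded_accepts(4, 10)` stops at 4. `unbounded_accepts(10)` takes all 10 without ever blocking.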
Aside from general code review advice, I'm concerned about polling the receiver alongside the `FuturesUnordered`. Is this the right thing to do? In particular, I had no idea what value to use for the `poll_recv_many()` limit, so I just picked 32 out of thin air. Would it be better for it to be the same value as (or a function of) the `limit` passed to `BoundedTreeNursery::new()`?
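The role of the limit can be modeled with std's channels. `recv_up_to()` and `drain_all()` are hypothetical helpers built on `std::sync::mpsc::Receiver::try_recv()`. They only loosely mirror `poll_recv_many()`, which also distinguishes "nothing queued yet" (`Poll::Pending`) from "closed and empty" (`Poll::Ready(0)`). The sketch shows that a fixed limit caps how much is moved per call. With more than 32 handles queued, a single call with a limit of 32 leaves the rest in the channel. Calling it in a loop until a batch comes back short collects everything, for any limit above zero.

```rust
use std::sync::mpsc::{Receiver, TryRecvError};

/// Moves at most `limit` already-queued messages into `buf` without blocking
/// and returns how many were moved.
fn recv_up_to<T>(rx: &Receiver<T>, buf: &mut Vec<T>, limit: usize) -> usize {
    let mut moved = 0;
    while moved < limit {
        match rx.try_recv() {
            Ok(item) => {
                buf.push(item);
                moved += 1;
            }
            // Nothing queued right now, or every sender is gone.
            Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
        }
    }
    moved
}

/// Calls `recv_up_to` repeatedly until a batch comes back short, so nothing
/// that was already queued is left behind, whatever `limit` is.
fn drain_all<T>(rx: &Receiver<T>, buf: &mut Vec<T>, limit: usize) -> usize {
    assert!(limit > 0, "a zero limit would never make progress");
    let mut total = 0;
    loop {
        let moved = recv_up_to(rx, buf, limit);
        total += moved;
        if moved < limit {
            return total;
        }
    }
}
```

With 100 messages queued, one `recv_up_to(&rx, &mut buf, 32)` moves 32, and a following `drain_all` moves the remaining 68 in batches of 32, 32 and 4. In an async `poll_next()`, the rough equivalent would keep calling `poll_recv_many()` until it returns `Poll::Pending` or `Poll::Ready(0)`, so the batch size would only affect how many handles move per iteration.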
```rust
use futures_util::{stream::FuturesUnordered, Stream, StreamExt};
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use tokio::{
    sync::{
        mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
        Semaphore,
    },
    task::{JoinError, JoinHandle},
};

/// A task group with the following properties:
///
/// - No more than a certain number of tasks are ever active at once.
///
/// - Each task is passed a `Spawner` that can be used to spawn more tasks in
///   the group.
///
/// - `BoundedTreeNursery<T>` is a `Stream` of the return values of the tasks
///   (which must all be `T`).
///
/// - Dropping `BoundedTreeNursery` causes all tasks to be aborted.
#[derive(Debug)]
pub(crate) struct BoundedTreeNursery<T> {
    receiver: UnboundedReceiver<FragileHandle<T>>,
    tasks: FuturesUnordered<FragileHandle<T>>,
    closed: bool,
}

impl<T: Send + 'static> BoundedTreeNursery<T> {
    /// Create a `BoundedTreeNursery` that limits the number of active tasks to
    /// at most `limit` and with `root` spawned as the initial task
    pub(crate) fn new<F, Fut>(limit: usize, root: F) -> Self
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let semaphore = Arc::new(Semaphore::new(limit));
        let (sender, receiver) = unbounded_channel();
        let spawner = Spawner { semaphore, sender };
        spawner.spawn_with_self(root);
        BoundedTreeNursery {
            tasks: FuturesUnordered::new(),
            receiver,
            closed: false,
        }
    }
}

impl<T: Send + 'static> Stream for BoundedTreeNursery<T> {
    type Item = T;

    /// Poll for one of the tasks in the group to complete and return its
    /// return value.
    ///
    /// # Panics
    ///
    /// If a task panics, this method resumes unwinding the panic.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let mut buf = Vec::new();
        match self.receiver.poll_recv_many(cx, &mut buf, 32) {
            Poll::Pending => (),
            Poll::Ready(0) => self.closed = true,
            Poll::Ready(_) => self.tasks.extend(buf),
        }
        match ready!(self.tasks.poll_next_unpin(cx)) {
            Some(Ok(r)) => Some(r).into(),
            Some(Err(e)) => match e.try_into_panic() {
                Ok(barf) => std::panic::resume_unwind(barf),
                Err(e) => unreachable!(
                    "Task in BoundedTreeNursery should not have been aborted, but got {e:?}"
                ),
            },
            None => {
                if self.closed {
                    // All spawners dropped and all results yielded; end of
                    // stream
                    None.into()
                } else {
                    Poll::Pending
                }
            }
        }
    }
}

/// A handle for spawning tasks in a `BoundedTreeNursery<T>`
#[derive(Debug)]
pub(crate) struct Spawner<T> {
    semaphore: Arc<Semaphore>,
    sender: UnboundedSender<FragileHandle<T>>,
}

// Clone can't be derived, as that would erroneously add `T: Clone` bounds to
// the impl.
impl<T> Clone for Spawner<T> {
    fn clone(&self) -> Spawner<T> {
        Spawner {
            semaphore: self.semaphore.clone(),
            sender: self.sender.clone(),
        }
    }
}

impl<T: Send + 'static> Spawner<T> {
    /// Spawn the given task in the task group, passing it a new `Spawner`
    pub(crate) fn spawn<F, Fut>(&self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        self.clone().spawn_with_self(func);
    }

    /// Spawn the given task in the task group, passing it this `Spawner`
    fn spawn_with_self<F, Fut>(self, func: F)
    where
        F: FnOnce(Spawner<T>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let semaphore = self.semaphore.clone();
        let sender = self.sender.clone();
        let _ = sender.send(FragileHandle::new(tokio::spawn(async move {
            let Ok(_permit) = semaphore.acquire().await else {
                unreachable!("Semaphore should not be closed");
            };
            func(self).await
        })));
    }
}

pin_project! {
    /// A wrapper around `tokio::task::JoinHandle` that aborts the task on drop.
    #[derive(Debug)]
    struct FragileHandle<T> {
        #[pin]
        inner: JoinHandle<T>
    }

    impl<T> PinnedDrop for FragileHandle<T> {
        fn drop(this: Pin<&mut Self>) {
            this.project().inner.abort();
        }
    }
}

impl<T> FragileHandle<T> {
    fn new(inner: JoinHandle<T>) -> Self {
        FragileHandle { inner }
    }
}

impl<T> Future for FragileHandle<T> {
    type Output = Result<T, JoinError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        this.inner.poll(cx)
    }
}
```
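`FragileHandle` is an instance of the RAII drop-guard pattern. Its cleanup runs whenever the guard is dropped, including when the `FuturesUnordered` holding it is dropped along with the nursery. The std-only sketch below shows the same pattern with a hypothetical `CancelOnDrop` guard. A std thread cannot be aborted, so here dropping the guard only raises a flag that the worker checks on each iteration. That makes it a loose stand-in for `JoinHandle::abort()`, not an equivalent.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;

/// Hypothetical drop guard: dropping it asks the associated worker to stop.
struct CancelOnDrop {
    cancelled: Arc<AtomicBool>,
}

impl Drop for CancelOnDrop {
    fn drop(&mut self) {
        // Runs however the guard goes away: explicit `drop`, end of scope,
        // or being dropped as part of a containing collection.
        self.cancelled.store(true, Ordering::SeqCst);
    }
}

/// Starts a worker that keeps running until its guard is dropped, and
/// returns the guard together with the worker's join handle. The worker
/// reports how many loop iterations it completed.
fn spawn_cancellable() -> (CancelOnDrop, JoinHandle<u64>) {
    let cancelled = Arc::new(AtomicBool::new(false));
    let flag = Arc::clone(&cancelled);
    let worker = thread::spawn(move || {
        let mut iterations = 0;
        while !flag.load(Ordering::SeqCst) {
            iterations += 1;
            thread::sleep(Duration::from_millis(1));
        }
        iterations
    });
    (CancelOnDrop { cancelled }, worker)
}
```

Dropping a `Vec` of such guards stops every worker. This mirrors how dropping the nursery's `FuturesUnordered` aborts every handle it holds.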
A partially-stubbed example of how this can be used to traverse a remote file hierarchy:
```rust
use futures_util::{future::BoxFuture, FutureExt, TryStreamExt};
use std::fmt;
use url::Url;

async fn traverse(base_url: Url, workers: usize) -> anyhow::Result<()> {
    let client = Client;
    let mut stream = BoundedTreeNursery::new(workers, move |spawner| {
        process_dir(spawner, client, base_url)
    });
    while let Some(r) = stream.try_next().await? {
        println!("{r}");
    }
    Ok(())
}

fn process_dir(
    spawner: Spawner<anyhow::Result<Report>>,
    client: Client,
    url: Url,
) -> BoxFuture<'static, anyhow::Result<Report>> {
    // We need to return a boxed Future in order to be able to call
    // `process_dir()` inside itself.
    async move {
        let dl = client.list_directory(url).await?;
        for d in dl.directories {
            let cl2 = client.clone();
            // `process_dir()` already returns a boxed future, so it can be
            // passed along without another `Box::pin()`.
            spawner.spawn(move |spawner| process_dir(spawner, cl2, d));
        }
        for f in dl.files {
            let cl2 = client.clone();
            spawner.spawn(move |_spawner| process_file(cl2, f));
        }
        Ok(Report)
    }
    .boxed()
}

async fn process_file(client: Client, url: Url) -> anyhow::Result<Report> {
    client.get_file(url).await
}

/// Stub client for traversing a web hierarchy of some sort (e.g., WebDAV)
#[derive(Clone, Debug)]
struct Client;

impl Client {
    // `DirectoryListing` is not generic, so it takes no type parameter here.
    async fn list_directory(&self, url: Url) -> anyhow::Result<DirectoryListing> {
        todo!()
    }

    async fn get_file(&self, url: Url) -> anyhow::Result<Report> {
        todo!()
    }
}

struct DirectoryListing {
    directories: Vec<Url>,
    files: Vec<Url>,
}

/// Information returned for each request
struct Report;

impl fmt::Display for Report {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        todo!()
    }
}
```
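The comment in `process_dir()` about boxing touches on a general Rust constraint on recursive futures. A future that directly contains a future of its own type would have infinite size, so the recursion has to go through a heap allocation such as `Pin<Box<dyn Future + Send>>`, which is what `BoxFuture` is an alias for. The sketch below shows that pattern on a hypothetical in-memory `Dir` tree. It awaits the recursive calls directly instead of spawning them. It also includes a tiny `block_on` so it runs with only the standard library. Neither is part of the original code.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

/// Waker that unparks the thread running `block_on`.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

/// Minimal single-future executor, only so this sketch runs without tokio.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            // Sleep until woken; a spurious wakeup just causes another poll.
            Poll::Pending => thread::park(),
        }
    }
}

/// Hypothetical directory tree standing in for the remote hierarchy.
struct Dir {
    files: usize,
    subdirs: Vec<Dir>,
}

/// Recursive traversal returning a boxed future, like `process_dir()`.
/// The box gives the future a fixed size even though it awaits itself.
fn count_files(dir: &Dir) -> Pin<Box<dyn Future<Output = usize> + Send + '_>> {
    Box::pin(async move {
        let mut total = dir.files;
        for sub in &dir.subdirs {
            total += count_files(sub).await;
        }
        total
    })
}
```

For a root with 2 files and subdirectories holding 3, 1 and (nested) 4 files, `block_on(count_files(&tree))` returns 10.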