Passing a Sender to an async procedure and streaming the sent values

As part of vague plans for a larger project, I've written the following functions and Stream for working with tokio's mpsc channels. Each function takes an FnOnce(/*Unbounded*/Sender<T>) -> impl Future<Output = ()>, that is, an async procedure that is handed the sending half of a channel, and returns a Stream that runs the procedure's future to completion while yielding the values it sends through the channel.
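This API has a synchronous counterpart in std that may help frame the question about spawning below. The sketch is only an analogue, not the async code: `std::sync::mpsc::sync_channel` and a spawned thread stand in for tokio's channel and the inline future, and `received_iter` is a hypothetical name. Because the procedure is spawned rather than driven by the consumer, it must be `Send + 'static`, and dropping the consumer does not cancel it; it only stops once a send fails.

```rust
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::thread::{self, JoinHandle};

/// Hypothetical synchronous analogue of `received_stream`: the procedure gets the
/// sending half of a bounded channel and runs on its own thread, while the caller
/// consumes the values through the returned `Receiver`.
fn received_iter<T, R, F>(buffer: usize, f: F) -> (Receiver<T>, JoinHandle<R>)
where
    T: Send + 'static,
    R: Send + 'static,
    F: FnOnce(SyncSender<T>) -> R + Send + 'static,
{
    let (sender, receiver) = sync_channel(buffer);
    // Spawned, so the procedure needs `Send + 'static` and outlives the consumer.
    let handle = thread::spawn(move || f(sender));
    (receiver, handle)
}

fn main() {
    // Every value the procedure sends reaches the consumer, in order.
    let (receiver, handle) = received_iter(1, |sender| {
        for i in 0..5 {
            sender.send(i).unwrap();
        }
    });
    let received: Vec<i32> = receiver.iter().collect();
    handle.join().unwrap();
    assert_eq!(received, vec![0, 1, 2, 3, 4]);

    // Dropping the consumer does not cancel the spawned procedure: it only notices
    // when a send fails, and then returns normally with its count of successful sends.
    let (receiver, handle) = received_iter(1, |sender| {
        let mut sent = 0u32;
        while sender.send(sent).is_ok() {
            sent += 1;
        }
        sent
    });
    assert_eq!(receiver.recv(), Ok(0));
    drop(receiver);
    let sent = handle.join().unwrap();
    assert!(sent >= 1);
    println!("producer stopped after {sent} successful sends");
}
```

`tokio::spawn` carries the same trade-off: the spawned future must be `Send + 'static`, and dropping its `JoinHandle` detaches the task rather than aborting it.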

use self::inner::Receiver as _;
use futures::future::{maybe_done, MaybeDone};
use futures::stream::Stream;
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc::{
    channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender,
};

/// `received_stream()` takes a buffer size and an async procedure that takes a
/// [`tokio::sync::mpsc::Sender`], and it returns a stream that runs the
/// procedure to completion while yielding the values passed to the sender.
///
/// If the stream is dropped before completion, the async procedure (which may
/// or may not have completed by that point) is dropped as well.
///
/// # Panics
///
/// Panics if `buffer` is 0, since [`tokio::sync::mpsc::channel`] requires a
/// nonzero capacity.
pub fn received_stream<F, Fut, T>(buffer: usize, f: F) -> ReceivedStream<Fut, T, Receiver<T>>
where
    F: FnOnce(Sender<T>) -> Fut,
    Fut: Future<Output = ()>,
{
    let (sender, receiver) = channel(buffer);
    let future = f(sender);
    ReceivedStream::new(future, receiver)
}

/// `unbounded_received_stream()` takes an async procedure that takes a
/// [`tokio::sync::mpsc::UnboundedSender`], and it returns a stream that runs
/// the procedure to completion while yielding the values passed to the sender.
///
/// If the stream is dropped before completion, the async procedure (which may
/// or may not have completed by that point) is dropped as well.
pub fn unbounded_received_stream<F, Fut, T>(f: F) -> ReceivedStream<Fut, T, UnboundedReceiver<T>>
where
    F: FnOnce(UnboundedSender<T>) -> Fut,
    Fut: Future<Output = ()>,
{
    let (sender, receiver) = unbounded_channel();
    let future = f(sender);
    ReceivedStream::new(future, receiver)
}

pin_project! {
    pub struct ReceivedStream<Fut, T, Recv> where Fut: Future {
        #[pin]
        future: MaybeDone<Fut>,
        receiver: inner::MaybeAllReceived<Recv>,
        _item: PhantomData<T>,
    }
}

impl<Fut: Future, T, Recv> ReceivedStream<Fut, T, Recv> {
    fn new(future: Fut, receiver: Recv) -> Self {
        ReceivedStream {
            future: maybe_done(future),
            receiver: inner::MaybeAllReceived::InProgress(receiver),
            _item: PhantomData,
        }
    }
}

impl<Fut, T, Recv> Stream for ReceivedStream<Fut, T, Recv>
where
    Fut: Future<Output = ()>,
    Recv: inner::Receiver<Item = T>,
{
    type Item = T;

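    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let this = self.project();
        // Drive the procedure; `MaybeDone` keeps returning `Ready` once it has finished.
        let fut_poll = this.future.poll(cx).map(|_| None);
        match this.receiver.poll_next_recv(cx) {
            // A value is available, even if the procedure has already finished.
            Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
            // The channel is empty or closed: end the stream only once the procedure
            // has completed, so it still runs to completion if it dropped its sender early.
            Poll::Ready(None) | Poll::Pending => fut_poll,
        }
    }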
}

mod inner {
    use std::task::{Context, Poll};
    use tokio::sync::mpsc;

    pub(super) trait Receiver {
        type Item;

        fn poll_next_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
    }

    impl<T> Receiver for mpsc::Receiver<T> {
        type Item = T;

        fn poll_next_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
            self.poll_recv(cx)
        }
    }

    impl<T> Receiver for mpsc::UnboundedReceiver<T> {
        type Item = T;

        fn poll_next_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
            self.poll_recv(cx)
        }
    }

    pub(super) enum MaybeAllReceived<Recv> {
        InProgress(Recv),
        Done,
    }

    impl<Recv: Receiver> Receiver for MaybeAllReceived<Recv> {
        type Item = <Recv as Receiver>::Item;

        fn poll_next_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            match self {
                MaybeAllReceived::InProgress(recv) => {
                    let p = recv.poll_next_recv(cx);
                    if matches!(p, Poll::Ready(None)) {
                        *self = MaybeAllReceived::Done;
                    }
                    p
                }
                MaybeAllReceived::Done => Poll::Ready(None),
            }
        }
    }
}
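The decision `poll_next` makes on each call can be checked in isolation. Below is a minimal sketch with a hypothetical `next_poll` helper that takes the procedure's state as a plain `bool`: it yields a value whenever the channel has one, and otherwise ends the stream only once the procedure has completed. In particular, a channel that closes while the procedure is still running (for example because it dropped its `Sender` early and carried on with other work) should give `Pending` rather than `None`, since the doc comments promise to run the procedure to completion. The last case ends the stream when the procedure has finished but the channel is still open, which can only happen if the sender, or a clone of it, outlived the procedure; waiting there instead could hang forever.

```rust
use std::task::Poll;

/// Hypothetical model of one `poll_next` call: `procedure_done` says whether the
/// wrapped future has completed, `recv` is what the channel returned on this poll.
fn next_poll<T>(procedure_done: bool, recv: Poll<Option<T>>) -> Poll<Option<T>> {
    match recv {
        // A value is available: yield it regardless of the procedure's state.
        Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
        // Channel closed or currently empty: the procedure decides.
        Poll::Ready(None) | Poll::Pending => {
            if procedure_done {
                Poll::Ready(None)
            } else {
                Poll::Pending
            }
        }
    }
}

fn main() {
    // Buffered values are still yielded after the procedure has finished.
    assert_eq!(next_poll(true, Poll::Ready(Some(1))), Poll::Ready(Some(1)));
    // All senders dropped but the procedure still running: keep driving it.
    assert_eq!(next_poll::<i32>(false, Poll::Ready(None)), Poll::Pending);
    // Procedure finished and channel drained: the stream ends.
    assert_eq!(next_poll::<i32>(true, Poll::Ready(None)), Poll::Ready(None));
    // Nothing received yet and the procedure still running: wait.
    assert_eq!(next_poll::<i32>(false, Poll::Pending), Poll::Pending);
}
```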

The code, with tests, on the Playground

My primary concerns are:

  • Does this correctly and efficiently do what I want?
  • Is using a public-in-private trait the best way to support both bounded & unbounded channels?
  • Would there be any value in spawning the future as a task instead of polling it concurrently with the receiver?
  • Ideas for more thorough tests?
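On the question about the trait: one common alternative to a `pub(super)` trait is the sealed-trait pattern (described in the Rust API Guidelines as C-SEALED). The trait is `pub`, but its supertrait lives in a private module, so downstream code can name it in bounds, for instance to be generic over `ReceivedStream`'s receiver type, yet cannot implement it. The std-only sketch below uses hypothetical names (`TryNext`, `sealed::Sealed`, `drain`), with `std::sync::mpsc::Receiver` and `VecDeque` standing in for tokio's `Receiver` and `UnboundedReceiver`; in the posted code the method would be `poll_next_recv`.

```rust
use std::collections::VecDeque;
use std::sync::mpsc;

mod sealed {
    /// Public trait in a private module: other crates cannot name it, so they
    /// cannot implement `TryNext` (which requires it) for their own types.
    pub trait Sealed {}
    impl<T> Sealed for std::sync::mpsc::Receiver<T> {}
    impl<T> Sealed for std::collections::VecDeque<T> {}
}

/// Hypothetical public trait over the supported receiver types.
pub trait TryNext: sealed::Sealed {
    type Item;
    fn try_next(&mut self) -> Option<Self::Item>;
}

impl<T> TryNext for mpsc::Receiver<T> {
    type Item = T;
    fn try_next(&mut self) -> Option<T> {
        // Empty or disconnected both end the drain in this sketch.
        self.try_recv().ok()
    }
}

impl<T> TryNext for VecDeque<T> {
    type Item = T;
    fn try_next(&mut self) -> Option<T> {
        self.pop_front()
    }
}

/// A public generic item whose bound names the sealed trait.
pub fn drain<R: TryNext>(mut source: R) -> Vec<R::Item> {
    let mut items = Vec::new();
    while let Some(item) = source.try_next() {
        items.push(item);
    }
    items
}

fn main() {
    let (sender, receiver) = mpsc::channel();
    sender.send(1).unwrap();
    sender.send(2).unwrap();
    assert_eq!(drain(receiver), vec![1, 2]);

    let queue: VecDeque<char> = "ab".chars().collect();
    assert_eq!(drain(queue), vec!['a', 'b']);
}
```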

Any other critiques are welcome as well.
