How can I implement fairness across multiple Tokio Receivers?

I have a Vec of Tokio mpsc Receivers, but I want to ensure fairness across them, ideally via a round robin approach.

How can I do this the right way? I would have thought there'd be an existing abstraction for this but couldn't find anything.

Something like this, maybe? (untested)

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;

struct RoundRobin {
    receivers: Vec<mpsc::Receiver<i32>>,
    start_idx: usize,
}

impl Future for RoundRobin {
    type Output = Option<i32>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let RoundRobin {
            receivers,
            start_idx,
        } = &mut *self;

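        if receivers.is_empty() {
            return Poll::Ready(None);
        }

        // A removal during an earlier poll may have left start_idx past the end.
        let mut i = *start_idx % receivers.len();
        // Final receiver of this pass: the one just before i, wrapping around.
        let mut last = i.checked_sub(1).unwrap_or(receivers.len() - 1);

        loop {
            match receivers[i].poll_recv(cx) {
                Poll::Ready(Some(x)) => {
                    // Start the next pass after this receiver so it cannot starve the others.
                    *start_idx = (i + 1) % receivers.len();
                    return Poll::Ready(Some(x));
                }
                Poll::Pending => {
                    if i == last {
                        break;
                    }
                    i = (i + 1) % receivers.len();
                }
                Poll::Ready(None) => {
                    // Closed and drained: drop it. Later receivers shift down by one.
                    let was_last = i == last;
                    receivers.remove(i);

                    if receivers.is_empty() {
                        return Poll::Ready(None);
                    }
                    if was_last {
                        break;
                    }
                    if last > i {
                        last -= 1;
                    }
                    // The following receiver now sits at i; wrap if the tail was removed.
                    i %= receivers.len();
                }
            }
        }

        // Every remaining receiver returned Pending, and each poll_recv call
        // registered cx's waker, so no explicit wake is needed here.
        Poll::Pending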
    }
}

Playground.

(Edit: forgot to handle an empty receiver vector)

(Edit 2: forgot to handle underflow of usize when computing last)

Thanks so much! A couple questions:

  1. I assume this could be easily adapted to impl Stream instead of Future since it's going to produce many values?
  2. I'm not super familiar with implementing Future / Stream directly, and I'm unclear about how the whole Waker thing works. I get that it's a mechanism for preventing the future from doing a hot loop, but how would it work here? I.e. is there some way this tells the runtime when to wake us up again based on the inner futures?

Also, I'm not understanding the indexing / modulo logic. Starting at 0 would produce an underflow wouldn't it? Why do modulo at all vs just iterating over all elements in order?

Why not just use a single receiver? Instead of creating new channels, you can clone the sender of an existing channel to get many senders to the same receiver. This will result in something fair.

I'm using a bounded channel, so if one of the Senders sends tons of messages then wouldn't that mean it will slow down the other Senders since the channel will be full and create backpressure?

I don't want one Sender to be able to influence the rate at which we process messages from others at all.
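
To make that concern concrete, here is a minimal sketch. It uses std's sync_channel as a stand-in for a bounded Tokio channel so it runs without a runtime; the function name and the value 999 are made up for illustration. Once one sender has filled the shared buffer, a clone of that sender is turned away too.

```rust
use std::sync::mpsc::{sync_channel, TrySendError};

/// Returns true if a second ("quiet") sender is rejected after a first
/// ("noisy") sender has filled the buffer they share.
/// Hypothetical helper; std's sync_channel stands in for a bounded Tokio channel.
fn quiet_sender_is_blocked(capacity: usize) -> bool {
    // _rx is kept alive so the channel stays open, but nothing is received.
    let (tx_noisy, _rx) = sync_channel::<u32>(capacity);
    let tx_quiet = tx_noisy.clone();

    // The noisy sender fills every slot of the shared buffer.
    for n in 0..capacity as u32 {
        tx_noisy.try_send(n).expect("buffer still has room");
    }

    // The quiet sender has sent nothing so far, yet there is no room for it.
    matches!(tx_quiet.try_send(999), Err(TrySendError::Full(_)))
}
```

With a blocking or async send instead of try_send, the quiet sender would wait here until the receiver drains something, which is exactly the cross-sender backpressure described above.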

I'm sure the modulo could be optimized by trading the cost of a division for a branch, but it's a common way to count up and down between 0 and N-1. I think it was written that way to better illustrate what it's doing, and probably to be safer.

  • (i + 1) % N produces i + 1 or goes back to 0 when it reaches N. It's functionally equivalent to if i < N - 1 { i + 1 } else { 0 } (except if i happens to be outside the range, which might happen here if a receiver is removed when i reaches the upper limit, but that'll reset it anyway).
  • (i + N - 1) % N does the same to count down. The expression in the code excerpt above can be simplified this way to avoid the checked_sub and unwrap_or, which are evaluated every time (I suppose that's the underflow that's bothering you?). It's functionally equivalent to if i > 0 { i - 1 } else { N - 1 }, except again if i happens to be outside the range.

EDIT: clarified the "outside the range"
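
As a quick sketch of the two expressions above next to their branching equivalents (the function names are just for illustration, and all of them assume n > 0):

```rust
/// Next index in 0..n, wrapping to 0 after n - 1. Assumes n > 0.
fn next_index(i: usize, n: usize) -> usize {
    (i + 1) % n
}

/// Previous index in 0..n, wrapping to n - 1 before 0. Assumes n > 0.
/// Adding n before subtracting keeps the usize from underflowing when i == 0.
fn prev_index(i: usize, n: usize) -> usize {
    (i + n - 1) % n
}

/// Branching form of next_index; identical for every i in 0..n.
fn next_index_branch(i: usize, n: usize) -> usize {
    if i < n - 1 { i + 1 } else { 0 }
}

/// Branching form of prev_index; identical for every i in 0..n.
fn prev_index_branch(i: usize, n: usize) -> usize {
    if i > 0 { i - 1 } else { n - 1 }
}
```

The two forms only disagree when i is already outside 0..n: for example, with i = 3 and n = 3 the modulo form of next gives 1 while the branching form gives 0.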

You can combine an unbounded mpsc channel with multiple semaphores to implement a per-sender limit. You send an owned semaphore permit together with the message to implement the limit correctly.

(In fact, the bounded channel is implemented by combining an unbounded channel with a semaphore.)

Oooh, awesome idea! Thank you!