I have a Vec of Tokio mpsc Receivers, but I want to ensure fairness across them, ideally via a round-robin approach.
How can I do this the right way? I would have thought there'd be an existing abstraction for this, but I couldn't find anything.
Something like this, maybe? (untested)
use futures::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;

// Polls a set of receivers in round-robin order, resuming after the
// receiver that yielded the previous message.
struct RoundRobin {
    receivers: Vec<mpsc::Receiver<i32>>,
    start_idx: usize,
}

impl Future for RoundRobin {
    type Output = Option<i32>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let RoundRobin {
            receivers,
            start_idx,
        } = &mut *self;
        if receivers.is_empty() {
            return Poll::Ready(None);
        }
        let mut i = *start_idx % receivers.len();
        // Index of the final receiver to poll in this pass: the one just
        // before start_idx, wrapping around. checked_sub avoids usize
        // underflow when start_idx is 0.
        let mut last = start_idx
            .checked_sub(1)
            .unwrap_or(receivers.len() - 1)
            % receivers.len();
        loop {
            let next = (i + 1) % receivers.len();
            let is_last = i == last;
            match receivers[i].poll_recv(cx) {
                Poll::Pending => {
                    if is_last {
                        // Every receiver has been polled once and has
                        // registered our waker, so we can simply park.
                        return Poll::Pending;
                    }
                    i = next;
                }
                Poll::Ready(Some(x)) => {
                    // Start the next pass after this receiver, for fairness.
                    *start_idx = next;
                    return Poll::Ready(Some(x));
                }
                Poll::Ready(None) => {
                    // This channel is closed: drop it. The element that
                    // shifted into slot i has not been polled yet, so
                    // don't advance past it.
                    receivers.remove(i);
                    if receivers.is_empty() {
                        return Poll::Ready(None);
                    }
                    if is_last {
                        return Poll::Pending;
                    }
                    if last > i {
                        last -= 1; // everything after i shifted down by one
                    }
                    i %= receivers.len(); // wrap if the tail element was removed
                }
            }
        }
    }
}
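Driving it could look something like this (also untested; since RoundRobin is Unpin, a &mut RoundRobin is itself a Future, so you can await it repeatedly):

async fn drain(receivers: Vec<mpsc::Receiver<i32>>) {
    let mut rr = RoundRobin {
        receivers,
        start_idx: 0,
    };
    // None means every channel has closed.
    while let Some(x) = (&mut rr).await {
        println!("got {x}");
    }
}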
(Edit: forgot to handle an empty receiver vector)
(Edit 2: forgot to handle underflow of usize when computing last)
Thanks so much! A couple of questions: I'm not understanding the indexing / modulo logic. Wouldn't starting at 0 produce an underflow? And why do modulo at all vs just iterating over all elements in order?
Why not just use a single receiver? Instead of creating new channels, you can clone the sender of an existing channel to get many senders to the same receiver. This will result in something fair.
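For instance, something like this (a rough sketch; the capacity, message type, and number of producers are arbitrary):

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<i32>(16);
    for id in 0..3 {
        let tx = tx.clone(); // each producer gets its own handle to the same channel
        tokio::spawn(async move {
            tx.send(id).await.unwrap();
        });
    }
    drop(tx); // drop the original handle so rx sees the channel close
    while let Some(x) = rx.recv().await {
        println!("got {x}");
    }
}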
I'm using a bounded channel, so if one of the Senders sends tons of messages, wouldn't that slow down the other Senders, since the channel will be full and create backpressure?
I don't want one Sender to be able to influence the rate at which we process messages from the others at all.
I'm sure it can be optimized to trade the cost of a division for the cost of a branch (and its prediction), but that's a common way to count up and down cyclically between 0 and N-1. I think it was written that way to better illustrate what it's doing, and probably to be safer.

(i + 1) % N produces i + 1, or wraps back to 0 when it reaches N. It's functionally equivalent to if i < N - 1 { i + 1 } else { 0 } (except if i happens to already be outside the range, which might happen here when a receiver is removed while i is at the upper limit, but the modulo resets it anyway).

(i + N - 1) % N does the same counting down. The expression in the code excerpt above can be simplified to that form, which avoids the checked_sub and unwrap_or that are evaluated every time (I suppose that's the underflow that's bothering you?). It's functionally equivalent to if i > 0 { i - 1 } else { N - 1 }, except again if i happens to be outside the range.

EDIT: clarified the "outside the range"
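To make the equivalence concrete, here's a tiny standalone check (prev_mod and prev_branch are just illustrative names):

// Cyclic predecessor of i in 0..n, assuming n > 0 and i < n.
fn prev_mod(i: usize, n: usize) -> usize {
    (i + n - 1) % n // adding n first keeps the subtraction from underflowing
}

fn prev_branch(i: usize, n: usize) -> usize {
    if i > 0 { i - 1 } else { n - 1 }
}

fn main() {
    let n = 5;
    for i in 0..n {
        assert_eq!(prev_mod(i, n), prev_branch(i, n));
    }
    // 0 wraps to 4; everything else steps down by one.
    println!("{:?}", (0..n).map(|i| prev_mod(i, n)).collect::<Vec<_>>());
}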
You can combine an unbounded mpsc channel with one semaphore per sender to implement a per-sender limit. You send an owned semaphore permit together with the message, so capacity for that sender is only released once the message has actually been received.
(In fact, the bounded channel is implemented by combining an unbounded channel with a semaphore.)
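A sketch of what that could look like (untested; the message type, loop bounds, and PER_SENDER_LIMIT are made up for illustration):

use std::sync::Arc;
use tokio::sync::{mpsc, OwnedSemaphorePermit, Semaphore};

const PER_SENDER_LIMIT: usize = 8; // max in-flight messages per sender

#[tokio::main]
async fn main() {
    // The permit travels with the message, so a slot is only freed
    // once the receiver has actually taken the message out.
    let (tx, mut rx) = mpsc::unbounded_channel::<(i32, OwnedSemaphorePermit)>();

    for id in 0..3 {
        let tx = tx.clone();
        let limit = Arc::new(Semaphore::new(PER_SENDER_LIMIT));
        tokio::spawn(async move {
            for n in 0..100 {
                // Waits only when *this* sender already has
                // PER_SENDER_LIMIT unprocessed messages queued.
                let permit = limit.clone().acquire_owned().await.unwrap();
                if tx.send((id * 1000 + n, permit)).is_err() {
                    break; // receiver is gone
                }
            }
        });
    }
    drop(tx);

    while let Some((msg, _permit)) = rx.recv().await {
        println!("got {msg}"); // _permit drops here, freeing a slot for that sender
    }
}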
Oooh, awesome idea! Thank you!