Would someone kindly tell me whether this is safe? I'm trying to split a stream into two streams using a predicate function. The code seems to work just fine, so I'm hoping that means the logic is correct. However, I was reading up on `Waker`s and saw that the waker for a future can change between polls (only the waker from the most recent poll is guaranteed to wake the right task), which is why a stored waker needs to be checked, and updated if necessary, on every poll. I wasn't clear on how that interacts with two streams having access to each other's waker.
Any feedback would be appreciated.
use futures::stream::{poll_fn, Stream, StreamExt};
use std::pin::Pin;
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{Context, Poll, Wake, Waker};
fn split_stream_by<I, S, P>(
stream: S,
predicate: P,
) -> (impl Stream<Item = I>, impl Stream<Item = I>)
where
S: Stream<Item = I> + Unpin,
P: Fn(&I) -> bool,
{
struct SplitBy<I, S, P> {
buf_true: Option<I>,
buf_false: Option<I>,
waker_true: Option<Waker>,
waker_false: Option<Waker>,
stream: S,
predicate: P,
}
let split_by = SplitBy {
buf_true: None,
buf_false: None,
waker_true: None,
waker_false: None,
stream,
predicate,
};
let split_by = Arc::new(Mutex::new(split_by));
let true_stream = poll_fn({
let split_by = split_by.clone();
move |cx| {
if let Ok(mut split_by) = split_by.try_lock() {
let SplitBy {
buf_true,
buf_false,
waker_true,
waker_false,
stream,
predicate,
} = &mut *split_by;
if let Some(waker) = waker_true {
if !waker.will_wake(cx.waker()) {
*waker = cx.waker().clone();
}
} else {
*waker_true = Some(cx.waker().clone());
}
if let Some(item) = buf_true.take() {
return Poll::Ready(Some(item));
}
if buf_false.is_some() {
// There is a value available for the other stream. Wake that stream if possible
// and return pending since we can't store multiple values for a stream
if let Some(waker) = waker_false {
waker.wake_by_ref();
}
return Poll::Pending;
}
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(item)) => {
if (predicate)(&item) {
Poll::Ready(Some(item))
} else {
// This value is not what we wanted. Store it and notify other partition task if
// it exists
let _ = buf_false.replace(item);
if let Some(waker) = waker_false {
waker.wake_by_ref();
}
Poll::Pending
}
}
Poll::Ready(None) => {
// If the underlying stream is finished, the `false` stream also must be
// finished, so wake it in case nothing else polls it
if let Some(waker) = waker_false {
waker.wake_by_ref();
}
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
} else {
cx.waker().wake_by_ref();
std::task::Poll::Pending
}
}
});
let false_stream = futures::stream::poll_fn({
let split_by = split_by.clone();
move |cx| {
if let Ok(mut split_by) = split_by.try_lock() {
let SplitBy {
buf_true,
buf_false,
waker_true,
waker_false,
stream,
predicate,
} = &mut *split_by;
if let Some(waker) = waker_false {
if !waker.will_wake(cx.waker()) {
*waker = cx.waker().clone();
}
} else {
*waker_false = Some(cx.waker().clone());
}
if let Some(item) = buf_false.take() {
return Poll::Ready(Some(item));
}
if buf_true.is_some() {
// There is a value available for the other stream. Wake that stream if possible
// and return pending since we can't store multiple values for a stream
if let Some(waker) = waker_true {
waker.wake_by_ref();
}
return Poll::Pending;
}
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(item)) => {
if (predicate)(&item) {
// This value is not what we wanted. Store it and notify other partition task if
// it exists
let _ = buf_true.replace(item);
if let Some(waker) = waker_true {
waker.wake_by_ref();
}
Poll::Pending
} else {
Poll::Ready(Some(item))
}
}
Poll::Ready(None) => {
// If the underlying stream is finished, the `true` stream also must be
// finished, so wake it in case nothing else polls it
if let Some(waker) = waker_true {
waker.wake_by_ref();
}
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
} else {
cx.waker().wake_by_ref();
std::task::Poll::Pending
}
}
});
(true_stream, false_stream)
}