Review Stream Splitting code

Could someone tell me whether this is safe? I'm trying to split a stream into two streams according to a predicate function. The code seems to work fine, so I'm hoping the logic is correct. However, while reading up on Wakers I saw that the waker passed to a task can change between polls (for instance, if the task is moved), which is why every poll should check whether the stored waker needs to be updated. What I'm not clear on is how that interacts with two streams that each have access to the other's waker.

Any feedback would be appreciated.

```rust
use futures::stream::{poll_fn, Stream};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Poll, Waker};

fn split_stream_by<I, S, P>(
    stream: S,
    predicate: P,
) -> (impl Stream<Item = I>, impl Stream<Item = I>)
where
    S: Stream<Item = I> + Unpin,
    P: Fn(&I) -> bool,
{
    struct SplitBy<I, S, P> {
        buf_true: Option<I>,
        buf_false: Option<I>,
        waker_true: Option<Waker>,
        waker_false: Option<Waker>,
        stream: S,
        predicate: P,
    }
    let split_by = SplitBy {
        buf_true: None,
        buf_false: None,
        waker_true: None,
        waker_false: None,
        stream,
        predicate,
    };
    let split_by = Arc::new(Mutex::new(split_by));
    let true_stream = poll_fn({
        let split_by = split_by.clone();
        move |cx| {
            if let Ok(mut split_by) = split_by.try_lock() {
                let SplitBy {
                    buf_true,
                    buf_false,
                    waker_true,
                    waker_false,
                    stream,
                    predicate,
                } = &mut *split_by;
                if let Some(waker) = waker_true {
                    if !waker.will_wake(cx.waker()) {
                        *waker = cx.waker().clone();
                    }
                } else {
                    *waker_true = Some(cx.waker().clone());
                }
                if let Some(item) = buf_true.take() {
                    return Poll::Ready(Some(item));
                }
                if buf_false.is_some() {
                    // There is a value available for the other stream. Wake that stream if possible
                    // and return pending since we can't store multiple values for a stream
                    if let Some(waker) = waker_false {
                        waker.wake_by_ref();
                    }
                    return Poll::Pending;
                }
                match Pin::new(stream).poll_next(cx) {
                    Poll::Ready(Some(item)) => {
                        if (predicate)(&item) {
                            Poll::Ready(Some(item))
                        } else {
                            // This value is not what we wanted. Store it and notify other partition task if
                            // it exists
                            let _ = buf_false.replace(item);
                            if let Some(waker) = waker_false {
                                waker.wake_by_ref();
                            }
                            Poll::Pending
                        }
                    }
                    Poll::Ready(None) => {
                        // If the underlying stream is finished, the `false` stream also must be
                        // finished, so wake it in case nothing else polls it
                        if let Some(waker) = waker_false {
                            waker.wake_by_ref();
                        }
                        Poll::Ready(None)
                    }
                    Poll::Pending => Poll::Pending,
                }
            } else {
                // The other stream currently holds the lock: ask to be polled again
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    });
    let false_stream = poll_fn({
        let split_by = split_by.clone();
        move |cx| {
            if let Ok(mut split_by) = split_by.try_lock() {
                let SplitBy {
                    buf_true,
                    buf_false,
                    waker_true,
                    waker_false,
                    stream,
                    predicate,
                } = &mut *split_by;
                if let Some(waker) = waker_false {
                    if !waker.will_wake(cx.waker()) {
                        *waker = cx.waker().clone();
                    }
                } else {
                    *waker_false = Some(cx.waker().clone());
                }
                if let Some(item) = buf_false.take() {
                    return Poll::Ready(Some(item));
                }
                if buf_true.is_some() {
                    // There is a value available for the other stream. Wake that stream if possible
                    // and return pending since we can't store multiple values for a stream
                    if let Some(waker) = waker_true {
                        waker.wake_by_ref();
                    }
                    return Poll::Pending;
                }
                match Pin::new(stream).poll_next(cx) {
                    Poll::Ready(Some(item)) => {
                        if (predicate)(&item) {
                            // This value is not what we wanted. Store it and notify other partition task if
                            // it exists
                            let _ = buf_true.replace(item);
                            if let Some(waker) = waker_true {
                                waker.wake_by_ref();
                            }
                            Poll::Pending
                        } else {
                            Poll::Ready(Some(item))
                        }
                    }
                    Poll::Ready(None) => {
                        // If the underlying stream is finished, the `true` stream also must be
                        // finished, so wake it in case nothing else polls it
                        if let Some(waker) = waker_true {
                            waker.wake_by_ref();
                        }
                        Poll::Ready(None)
                    }
                    Poll::Pending => Poll::Pending,
                }
            } else {
                // The other stream currently holds the lock: ask to be polled again
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    });
    (true_stream, false_stream)
}
```

The main challenge is that, as written here, the underlying stream only ever receives the `cx` of whichever split stream polled it last, so when it later has an item ready, only that stream is woken. You generally want both streams to be notified.
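One possible way to get both streams notified, sketched here with only the standard library (this is an assumption about how you might restructure the code, not something from the thread): keep the latest waker of each side in a shared slot, and poll the underlying stream with a combined waker that forwards every wake-up to both sides. `WakeBoth`, `new` and `register` are hypothetical names, not part of `futures` or `std`. In the code above you would call `register` near the top of each poll, build a waker with `Waker::from(shared.clone())`, wrap it with `Context::from_waker`, and pass that context to `poll_next` instead of `cx`.

```rust
use std::sync::{Arc, Mutex};
use std::task::{Wake, Waker};

/// Hypothetical helper: remembers the latest waker of each split stream
/// (index 0 = `true` side, index 1 = `false` side) and, when used as a
/// waker itself, forwards every wake-up to both of them.
pub struct WakeBoth {
    wakers: Mutex<[Option<Waker>; 2]>,
}

impl WakeBoth {
    pub fn new() -> Arc<Self> {
        Arc::new(WakeBoth {
            wakers: Mutex::new([None, None]),
        })
    }

    /// Store `waker` for `side`, cloning it only when the stored waker
    /// would not wake the same task (the `will_wake` check from the question).
    pub fn register(&self, side: usize, waker: &Waker) {
        let mut wakers = self.wakers.lock().unwrap();
        let slot = &mut wakers[side];
        let stale = match slot {
            Some(old) => !old.will_wake(waker),
            None => true,
        };
        if stale {
            *slot = Some(waker.clone());
        }
    }
}

impl Wake for WakeBoth {
    fn wake(self: Arc<Self>) {
        Wake::wake_by_ref(&self);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Copy the wakers out so the lock is not held while waking,
        // then wake every registered side, not just the last poller.
        let wakers = self.wakers.lock().unwrap().clone();
        for waker in wakers.iter().flatten() {
            waker.wake_by_ref();
        }
    }
}
```

The trade-off is that each side may receive spurious wake-ups; that is normally harmless, since a woken stream that has nothing to yield just returns `Poll::Pending` again.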

It would be a lot simpler to spawn a task that drives the original stream and sends each item as a message to whichever side it belongs to.
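A minimal sketch of the "spawn a task and send messages" shape, using only the standard library: a `std::thread` stands in for an async task and `std::sync::mpsc` channels stand in for async ones. In async code you would swap in an async channel such as `futures::channel::mpsc` or `tokio::sync::mpsc`, and a stream instead of an iterator. `split_by_channels` is a hypothetical name for illustration.

```rust
use std::sync::mpsc::{channel, Receiver};
use std::thread;

/// Hypothetical helper: a background thread drives `items` and routes each
/// one to the `true` or `false` receiver according to `predicate`.
fn split_by_channels<I, P>(items: I, predicate: P) -> (Receiver<I::Item>, Receiver<I::Item>)
where
    I: IntoIterator + Send + 'static,
    I::Item: Send + 'static,
    P: Fn(&I::Item) -> bool + Send + 'static,
{
    let (tx_true, rx_true) = channel();
    let (tx_false, rx_false) = channel();
    thread::spawn(move || {
        for item in items {
            let tx = if predicate(&item) { &tx_true } else { &tx_false };
            // A send only fails if that receiver was dropped; keep serving the other side.
            let _ = tx.send(item);
        }
        // Both senders are dropped here, which ends both receivers' iterators.
    });
    (rx_true, rx_false)
}
```

Each side is now an independent receiver, so there is no shared waker bookkeeping at all. Because these channels are unbounded, though, items for a side that nobody reads simply pile up in memory.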


Another tricky situation: if a lot of items go to the left stream but you're only reading from the right stream, you get unbounded memory usage for queuing, unless you block one stream until the other is read.


Thanks, that makes sense, @alice, as usual. Channels do seem like an easier approach. Regarding your second comment, I guess with either approach you'd need to make sure both channels/streams have active tasks polling them to avoid those issues.

Well, yes, channels don't remove these issues by themselves, but using two bounded channels sounds like a good solution to me.
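A sketch of the two-bounded-channels idea, again using a std thread and `std::sync::mpsc::sync_channel` as stand-ins for an async task and a bounded async channel (the function name `split_by_bounded` and the `capacity` parameter are hypothetical). Each side buffers at most `capacity` items, so memory stays bounded; the cost is that the producer blocks when one side falls behind, which also stalls the other side until the slow side is read.

```rust
use std::sync::mpsc::{sync_channel, Receiver};
use std::thread;

/// Hypothetical helper: like an unbounded split, but each side's buffer
/// holds at most `capacity` items before the producer blocks.
fn split_by_bounded<I, P>(
    items: I,
    predicate: P,
    capacity: usize,
) -> (Receiver<I::Item>, Receiver<I::Item>)
where
    I: IntoIterator + Send + 'static,
    I::Item: Send + 'static,
    P: Fn(&I::Item) -> bool + Send + 'static,
{
    let (tx_true, rx_true) = sync_channel(capacity);
    let (tx_false, rx_false) = sync_channel(capacity);
    thread::spawn(move || {
        for item in items {
            let tx = if predicate(&item) { &tx_true } else { &tx_false };
            // Blocks while this side's buffer is full (backpressure).
            // Err means that receiver was dropped; keep serving the other side.
            let _ = tx.send(item);
        }
    });
    (rx_true, rx_false)
}
```

As the thread notes, this only works if both sides are actively drained (for example, each on its own task), or if a side you don't need is dropped so that sends to it fail instead of blocking.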