How to connect a Tokio MPSC channel to a fallible stream consumer?

I have a bidirectional streaming gRPC endpoint (namely GCP Pub/Sub streaming_pull). The inbound stream is a stream of Pub/Sub messages, and the outbound stream is a stream of ack ids. In my current implementation, ack ids are sent to a Tokio MPSC channel, and I wrap the receiver in a ReceiverStream to connect the channel to the gRPC call (the outbound parameter). For ease of understanding, I have extracted the code into a small snippet:

```rust
use std::pin::Pin;
use std::time::Duration;

use futures_core::Stream;
use tokio::sync::mpsc::Receiver;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;

#[tokio::main]
async fn main() {
    let (sender, receiver) = tokio::sync::mpsc::channel(1000);

    let task = channel_processor(receiver).await;

    let mut counter = 0;
    loop {
        if sender.send(format!("ack_{}", counter)).await.is_err() {
            // The processor dropped the receiver; stop producing ack ids.
            break;
        }
        counter += 1;
        tokio::time::sleep(Duration::from_millis(100)).await;
    }

    let _ = task.await;
}

async fn channel_processor(receiver: Receiver<String>) -> JoinHandle<()> {
    tokio::spawn(async move {
        // ReceiverStream takes ownership of the receiver, and streaming_pull takes
        // ownership of the stream, so both are gone once the call returns.
        let receiver_stream = ReceiverStream::new(receiver);
        streaming_pull(Box::pin(receiver_stream)).await;
    })
}

// Opens the bidirectional call; `stream` is the outbound stream of ack ids.
async fn streaming_pull(stream: Pin<Box<dyn Stream<Item = String> + Send + Sync + 'static>>) {
    // connection to server
}
```

This works fine until there is a network failure or a connection timeout. Ideally, I'd like to call streaming_pull again with the same stream to continue processing. Unfortunately, the stream is consumed (moved) by streaming_pull, and since the receiver can't be cloned, I can't figure out how to get around that. Any help will be appreciated.

Consider having a mechanism to send it a new sender whenever you replace the receiver. Alternatively, wrap the receiver in a custom type that sends the receiver back to you through a oneshot channel in its destructor. Or don't give up ownership of the receiver in the first place, and hand out only a mutable borrow.

Thank you for the quick response. The first option isn't possible because I need to keep the unprocessed ack ids in the channel. I'll try the second solution, because I don't see how to avoid giving ownership to ReceiverStream. I'd like to understand the third approach a little better, though.

I finally found the time to implement a solution based on a oneshot channel and std::mem::take. It's not perfect: I didn't find a way to encode in the type system that stream() must not be called twice in the same scope. Please tell me what you think.

```rust
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_core::Stream;
use futures_util::StreamExt;
use tokio::sync::mpsc::Receiver;
use tokio::sync::oneshot::error::TryRecvError;
use tokio::task::JoinHandle;

#[tokio::main]
async fn main() {
    let (sender, receiver) = tokio::sync::mpsc::channel(1000);

    let task = channel_processor(receiver).await;

    let mut counter = 0;
    loop {
        if sender.send(format!("ack_id_{}", counter)).await.is_err() {
            // The processor dropped the receiver; stop producing ack ids.
            break;
        }
        counter += 1;
        tokio::time::sleep(Duration::from_millis(100)).await;
    }

    let _ = task.await;
}

async fn channel_processor(receiver: Receiver<String>) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut reusable_receiver = ReusableReceiver::new(receiver);

        // Each iteration hands a fresh stream to the consumer; when the consumer
        // fails and drops it, the receiver (and any unread ack ids) comes back.
        while let Err(error) = fallible_stream_consumer(reusable_receiver.stream()).await {
            // Do something with the error, e.g. logging
            eprintln!("stream consumer failed, restarting: {}", error);
        }
    })
}

// Simulates a consumer that fails (e.g. a network error) after 10 items.
// There is no input reference to borrow from, so the error needs `&'static str`.
async fn fallible_stream_consumer(
    mut stream: Pin<Box<dyn Stream<Item = String> + Send + Sync + 'static>>,
) -> Result<(), &'static str> {
    let mut countdown_before_failure = 10;

    while let Some(value) = stream.next().await {
        println!("process {:?}", value);

        countdown_before_failure -= 1;
        if countdown_before_failure == 0 {
            return Err("SomeError");
        }
    }
    // The channel is closed and drained.
    Ok(())
}

/// Owns the receiver between streams and recovers it from the previous stream.
pub struct ReusableReceiver<T> {
    receiver: Option<Receiver<T>>,
    receiver_oneshot: Option<tokio::sync::oneshot::Receiver<Receiver<T>>>,
}

/// Lends the receiver to one consumer and sends it back when dropped.
pub struct ReusableReceiverStream<T> {
    receiver: Option<Receiver<T>>,
    sender_oneshot: Option<tokio::sync::oneshot::Sender<Receiver<T>>>,
}

impl<T> ReusableReceiver<T>
where
    T: Sync + Send + 'static,
{
    pub fn new(receiver: Receiver<T>) -> Self {
        Self {
            receiver: Some(receiver),
            receiver_oneshot: None,
        }
    }

    pub fn stream(&mut self) -> Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>> {
        if self.receiver.is_none() {
            self.recover_receiver();
        }

        let (sender_oneshot, receiver_oneshot) = tokio::sync::oneshot::channel();

        self.receiver_oneshot = Some(receiver_oneshot);

        Box::pin(ReusableReceiverStream {
            receiver: std::mem::take(&mut self.receiver),
            sender_oneshot: Some(sender_oneshot),
        })
    }

    fn recover_receiver(&mut self) {
        let mut receiver_oneshot = std::mem::take(&mut self.receiver_oneshot)
            .expect("receiver is missing but no stream was handed out");

        // The previous stream sends the receiver back from its Drop impl, so it
        // must already be waiting here. Looping on `TryRecvError::Empty` would spin
        // forever (blocking the runtime thread) if the previous stream were still
        // alive, so fail loudly instead.
        match receiver_oneshot.try_recv() {
            Ok(receiver) => self.receiver = Some(receiver),
            Err(TryRecvError::Empty) => {
                panic!("stream() called while the previous stream is still alive")
            }
            Err(TryRecvError::Closed) => {
                panic!("previous stream was dropped without returning the receiver")
            }
        }
    }
}

impl<T> Drop for ReusableReceiverStream<T> {
    fn drop(&mut self) {
        let receiver = std::mem::take(&mut self.receiver)
            .expect("the stream owns the receiver until it is dropped");
        let sender_oneshot = std::mem::take(&mut self.sender_oneshot)
            .expect("the stream owns the oneshot sender until it is dropped");
        // `send` only fails if the ReusableReceiver is gone, in which case nobody
        // needs the receiver back and it can simply be dropped.
        let _ = sender_oneshot.send(receiver);
    }
}

impl<T> Stream for ReusableReceiverStream<T> {
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.receiver
            .as_mut()
            .expect("the stream owns the receiver until it is dropped")
            .poll_recv(cx)
    }
}
```