Async Read/Write trait objects not closing when the connection drops

Hi everyone,

To give some context to my issue:

  • I'm abstracting an async-std TcpStream or TlsStream as a split reader (Box<dyn AsyncRead + Send + Sync + Unpin>) and writer (Box<dyn AsyncWrite + Send + Sync + Unpin>)
  • I'm wrapping the reader trait object in a struct with a Stream impl, and the writer trait object in a struct with a Sink impl

I've run into an issue: when the underlying TCP connection is closed, for example by terminating the server program, the client program does not notice. Specifically, the Stream impl on the reader does not end; it keeps awaiting the async read call.

Below is the simplified Stream impl for the struct that holds the AsyncRead trait object in its read_stream field.

impl Stream for ConnectionReader {
    type Item = Any;

    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut buffer = BytesMut::new();
        buffer.resize(BUFFER_SIZE, 0);
        // do some prep stuff

        loop {
            trace!("Reading from the stream");
            match futures::executor::block_on(self.read_stream.read(&mut buffer)) {
                Ok(mut bytes_read) => {
                    if bytes_read > 0 {
                        // do some parsing stuff with the read bytes
                        Poll::Ready(Some(data));
                    } else {
                        Poll::Pending;
                    }
                }

                // Close the stream
                Err(_err) => {
                    trace!("Closing the stream");
                    // flush pending stuff
                    return Poll::Ready(None);
                }
            }
        }
    }
}

As I mentioned above, the code just awaits the result of the async read call, which never returns an Err, or anything else, even though the connection has been closed.

How do I fix this bug? How should I rework the Stream impl to avoid this issue?

Thanks in advance!

You should not call futures::executor::block_on from inside a poll method, because this blocks the thread. You cannot rely on IO notifications if you don't yield back to the executor.

Instead you should be calling the poll_read method, properly handling when it returns Poll::Pending by forwarding the Poll::Pending to the caller. Your poll method will then be called again whenever the IO resource has an event.
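
As a rough sketch of that shape, here is a poll_next that calls poll_read and forwards Poll::Pending instead of blocking. To keep it buildable without external crates, the AsyncRead and Stream traits below are minimal stand-ins for the ones in the futures crate, and ScriptedReader, NoopWake, demo, the Vec<u8> item type and the 1024-byte buffer are all hypothetical names and values chosen for illustration, not taken from the original code.

```rust
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Minimal stand-ins for `futures::io::AsyncRead` and `futures::Stream`,
// declared here only so the sketch builds without external crates.
trait AsyncRead {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>>;
}

trait Stream {
    type Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
}

struct ConnectionReader {
    read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
    buffer: Vec<u8>,
}

impl Stream for ConnectionReader {
    type Item = Vec<u8>; // assumption: raw bytes stand in for the parsed item

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        match Pin::new(&mut *this.read_stream).poll_read(cx, &mut this.buffer) {
            // Not ready: hand Pending back to the executor. The reader has
            // the waker from `cx` and wakes the task on the next IO event.
            Poll::Pending => Poll::Pending,
            // Zero bytes means the peer closed the connection: end the stream.
            Poll::Ready(Ok(0)) => Poll::Ready(None),
            Poll::Ready(Ok(n)) => Poll::Ready(Some(this.buffer[..n].to_vec())),
            // This sketch also ends the stream on a read error.
            Poll::Ready(Err(_)) => Poll::Ready(None),
        }
    }
}

// Hypothetical test double: `None` in the script means "not ready yet",
// `Some(bytes)` means "these bytes arrived", an empty script means EOF.
struct ScriptedReader {
    script: VecDeque<Option<&'static [u8]>>,
}

impl AsyncRead for ScriptedReader {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut().script.pop_front() {
            Some(Some(bytes)) => {
                buf[..bytes.len()].copy_from_slice(bytes);
                Poll::Ready(Ok(bytes.len()))
            }
            Some(None) => {
                // A real reader registers the waker with the reactor; this
                // double just asks to be polled again right away.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            None => Poll::Ready(Ok(0)),
        }
    }
}

struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Polls the stream three times by hand and returns what each poll produced.
fn demo() -> Vec<Poll<Option<Vec<u8>>>> {
    let reader = ScriptedReader {
        script: VecDeque::from(vec![None, Some(&b"hello"[..])]),
    };
    let mut stream = ConnectionReader {
        read_stream: Box::new(reader),
        buffer: vec![0; 1024],
    };
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    (0..3).map(|_| Pin::new(&mut stream).poll_next(&mut cx)).collect()
}
```

Note that poll_next never loops or blocks here: each call does one poll_read and returns immediately, so the executor stays free to run other tasks and to deliver the IO notification.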

Check out Async in depth and Streams from the Tokio tutorial. Consider also the async-stream crate, which (unlike poll methods!) lets you write ordinary imperative async code and turn it into a stream.

You need to recombine the sink/stream into a TcpStream, then call poll_shutdown so that the FIN packet is sent at the TCP layer. Here's some code I use to wrap a Framed object so that it automatically closes the stream on drop (you can use Wireshark to verify):

Wrap your Framed via CleanFramedShutdown::wrap

use parking_lot::Mutex;
use futures::stream::{SplitSink, SplitStream};
use std::sync::Arc;
use std::pin::Pin;
use tokio::prelude::{AsyncWrite, AsyncRead};
use tokio_util::codec::{Framed, Encoder, Decoder};
use std::ops::{Deref, DerefMut};
use futures::StreamExt;

struct CleanFramedShutdownInner<S, U, I> {
    sink: Option<SplitSink<Framed<S, U>, I>>,
    stream: Option<SplitStream<Framed<S, U>>>
}

pub struct CleanFramedShutdown<S, U, I> {
    inner: Arc<Mutex<CleanFramedShutdownInner<S, U, I>>>
}

pub struct CleanShutdownSink<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> {
    ptr: CleanFramedShutdown<S, U, I>,
    inner: Option<SplitSink<Framed<S, U>, I>>
}

pub struct CleanShutdownStream<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> {
    ptr: CleanFramedShutdown<S, U, I>,
    inner: Option<SplitStream<Framed<S, U>>>
}

impl<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> CleanShutdownSink<S, U, I> {
    pub fn new(inner: SplitSink<Framed<S, U>, I>, ptr: CleanFramedShutdown<S, U, I>) -> Self {
        Self { inner: Some(inner), ptr }
    }
}

impl<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> Drop for CleanShutdownSink<S, U, I> {
    #[allow(unused_results)]
    fn drop(&mut self) {
        let inner = self.inner.take().unwrap();
        let ptr = self.ptr.clone();
        tokio::task::spawn(ptr.push_sink(inner));
    }
}

impl<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> CleanShutdownStream<S, U, I> {
    pub fn new(inner: SplitStream<Framed<S, U>>, ptr: CleanFramedShutdown<S, U, I>) -> Self {
        Self { inner: Some(inner), ptr }
    }
}

impl<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> Drop for CleanShutdownStream<S, U, I> {
    #[allow(unused_results)]
    fn drop(&mut self) {
        let inner = self.inner.take().unwrap();
        let ptr = self.ptr.clone();
        tokio::task::spawn(ptr.push_stream(inner));
    }
}

impl<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> Deref for CleanShutdownSink<S, U, I> {
    type Target = SplitSink<Framed<S, U>, I>;

    fn deref(&self) -> &Self::Target {
        self.inner.as_ref().unwrap()
    }
}

impl<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> DerefMut for CleanShutdownSink<S, U, I> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.as_mut().unwrap()
    }
}

impl<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> Deref for CleanShutdownStream<S, U, I> {
    type Target = SplitStream<Framed<S, U>>;

    fn deref(&self) -> &Self::Target {
        self.inner.as_ref().unwrap()
    }
}

impl<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> DerefMut for CleanShutdownStream<S, U, I> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.as_mut().unwrap()
    }
}

impl<S: AsyncWrite + AsyncRead + Unpin + 'static, U: Encoder<I, Error: From<std::io::Error>> + Decoder + 'static, I: 'static> CleanFramedShutdown<S, U, I> {
    pub fn new() -> Self {
        let inner = CleanFramedShutdownInner { sink: None, stream: None };
        CleanFramedShutdown { inner: Arc::new(Mutex::new(inner)) }
    }

    pub fn wrap(framed: Framed<S, U>) -> (CleanShutdownSink<S, U, I>, CleanShutdownStream<S, U, I>) {
        let ptr = Self::new();
        let (sink, stream) = framed.split();
        let sink = CleanShutdownSink::new(sink, ptr.clone());
        let stream = CleanShutdownStream::new(stream, ptr);
        (sink, stream)
    }

    pub async fn push_sink(self, sink: SplitSink<Framed<S, U>, I>) {
        let mut this = self.inner.lock();
        if let Some(stream) = this.stream.take() {
            let reunited = stream.reunite(sink);
            if let Ok(reunited) = reunited {
                Self::reunite_fn(reunited).await
            }
        } else {
            this.sink = Some(sink);
        }
    }

    pub async fn push_stream(self, stream: SplitStream<Framed<S, U>>) {
        let mut this = self.inner.lock();
        if let Some(sink) = this.sink.take() {
            let reunited = sink.reunite(stream);
            if let Ok(reunited) = reunited {
                Self::reunite_fn(reunited).await
            }
        } else {
            this.stream = Some(stream);
        }
    }

    #[allow(unused_must_use)]
    async fn reunite_fn(reunited: Framed<S, U>) {
        let mut reunited = reunited.into_inner();
        futures::future::poll_fn(move |cx| {
            Pin::new(&mut reunited).poll_shutdown(cx)
        }).await;
    }
}

impl<S, U, I> Clone for CleanFramedShutdown<S, U, I> {
    fn clone(&self) -> Self {
        Self { inner: self.inner.clone() }
    }
}

In addition to the replies above: when the connection is closed and you have reached the end of the stream, the read call will yield Ok(0) (and poll_read will return Poll::Ready(Ok(0))). In your example code, when bytes_read is zero, you should treat that as connection closure and stop looping and reading.
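
The blocking std::io::Read trait follows the same Ok(0) convention, so a synchronous reader can serve as a small stand-in to show the loop shape. This is only a sketch: read_until_eof is a hypothetical helper name, and the 4-byte buffer is deliberately tiny so that several reads are needed.

```rust
use std::io::{self, Read};

/// Reads until the source reports end of stream, which `read` signals as `Ok(0)`.
fn read_until_eof<R: Read>(mut reader: R) -> io::Result<Vec<u8>> {
    let mut received = Vec::new();
    let mut buffer = [0u8; 4];
    loop {
        match reader.read(&mut buffer) {
            // Zero bytes read: the other side is done, so stop looping.
            Ok(0) => return Ok(received),
            Ok(n) => received.extend_from_slice(&buffer[..n]),
            // An interrupted read is safe to retry.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
}
```

The important branch is Ok(0): it is a normal return value, not an Err, so code that only watches for errors will never see the close.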

For some related discussion, see How to detect TCP close?

Just to be absolutely clear, the comment about bytes_read == 0 is also important to fix, but you cannot just fix that and ignore the block_on issue. Calling block_on in a poll method is always incorrect.

@nologik That snippet seems unnecessarily complicated.

Thanks everyone!

Both the bytes_read and block_on issues make sense to me, and hopefully by resolving these the problem is addressed.

To provide an analogy to the blog post I linked about blocking the thread (here it is again), what I called "reaching an .await" in that blog post will in our case correspond to returning from poll_next. Thus, if the duration of time from poll_next being called to it returning is ever more than 10-100 microseconds, then you can be confident that your solution is incorrect.
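
One way to check this, sketched under assumptions: wrap a future so that every call to poll is timed and the longest one is remembered. TimedPoll, NoopWake and run_timed are hypothetical names for illustration, and run_timed is a demo-only driver that polls in a busy loop, which is not how a real executor behaves.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::time::{Duration, Instant};

/// Wraps a future and records the longest single `poll` call.
struct TimedPoll<F> {
    inner: Pin<Box<F>>,
    longest_poll: Duration,
}

impl<F: Future> TimedPoll<F> {
    fn new(inner: F) -> Self {
        Self { inner: Box::pin(inner), longest_poll: Duration::ZERO }
    }
}

impl<F: Future> Future for TimedPoll<F> {
    // Yields the inner output together with the longest observed poll.
    type Output = (F::Output, Duration);

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let start = Instant::now();
        let result = this.inner.as_mut().poll(cx);
        this.longest_poll = this.longest_poll.max(start.elapsed());
        match result {
            Poll::Ready(output) => Poll::Ready((output, this.longest_poll)),
            Poll::Pending => Poll::Pending,
        }
    }
}

struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Demo-only driver: polls in a busy loop until the future completes.
fn run_timed<F: Future>(future: F) -> (F::Output, Duration) {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut timed = TimedPoll::new(future);
    loop {
        if let Poll::Ready(result) = Pin::new(&mut timed).poll(&mut cx) {
            return result;
        }
    }
}
```

For example, run_timed(async { std::thread::sleep(Duration::from_millis(2)); 7 }) reports a longest poll of at least 2 milliseconds, far above the rule of thumb, because the sleep blocks inside a single poll.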

I really liked these two points from the post especially:

  • Async code should never spend a long time without reaching an .await .
  • To give a sense of scale of how much time is too much, a good rule of thumb is no more than 10 to 100 microseconds between each .await .

"Time-to-await" isn't really something I was thinking about before and I'll have to comb through my code to change that going forward.

I often reach for block_on when I can't use await, like when I'm in a sync block that can't be made into an async block. A good example of this is the mistake you pointed out in my code above: I could not use await inside poll_next, so I used block_on. If I'm understanding correctly, if the calls to block_on also don't return within at most 100 microseconds, then something is wrong in the design of the async code being blocked on.

I'm curious though, how did you arrive at the numbers of 10-100 microseconds? What was the intuition behind this threshold and what use cases did you have in mind? Understanding this might help me put more faith behind this design principle.

Yeah. This is a common mistake. Besides that blog post, most of the things I've written to make sure people get started in the right direction with async are in the Tokio tutorial, but people keep finding the async-std book instead because of the project's more official sounding name that includes the word async.

With block_on in particular, I would consider it wrong even if it did return within 10 microseconds. I don't think async-std does this, but Tokio has a feature called automatic cooperative task yielding, which can cause deadlocks if combined with block_on, and even with async-std, libraries can have features that cause similar problems.

Ultimately the exact threshold is somewhat arbitrary. As you may suspect, it is actually a continuous scale where lower is better. The use case is the classical use case of async Rust, namely networking applications where you have a lot of tasks running concurrently to handle a lot of connections, and here, a time-to-await of a few milliseconds in each task can very quickly add up to give large tail latencies for the unlucky tasks that end up at the back of the queue.
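
To put rough numbers on that, here is a back-of-the-envelope sketch. It assumes a deliberately simplified model that is not from the blog post: one worker thread, and every other ready task holding it for the full time-to-await before the last task gets its turn. worst_case_wait is a hypothetical helper name.

```rust
use std::time::Duration;

/// Worst-case wait for the task at the back of the queue, assuming a single
/// worker thread and that each of the other tasks holds it for `time_to_await`.
fn worst_case_wait(tasks: u32, time_to_await: Duration) -> Duration {
    time_to_await * tasks.saturating_sub(1)
}
```

Under this model, 1,000 ready tasks with a 2 millisecond time-to-await leave the last task waiting about 1.998 seconds, while the same 1,000 tasks at 50 microseconds leave it waiting about 49.95 milliseconds.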

Thanks for such a thorough reply!