Critique wanted: TlsListener implementation

To make it easy to swap between a plain TcpListener and a TlsAcceptor, I added a wrapper around TlsAcceptor that lets it be polled like a stream. In my tests the wrapper appears to work, but I'd like a critique from the async experts here (you know who you are!) in case there's a better way of doing this. In particular, please look at the Stream implementation for TlsListener.

pub struct TlsListener {
    inner: TcpListener,
    tls_acceptor: Arc<TlsAcceptor>,
    queue: Vec<Pin<Box<dyn Future<Output=Result<TlsStream<TcpStream>, tokio_native_tls::native_tls::Error>>>>>
}

impl TlsListener {
    /// Wraps `inner` so that every accepted TCP connection is upgraded to TLS using `identity`.
    ///
    /// # Errors
    /// Returns an `io::Error` of kind `ConnectionRefused` if native-tls rejects `identity`.
    pub fn new(inner: TcpListener, identity: Identity) -> std::io::Result<Self> {
        let acceptor = tokio_native_tls::native_tls::TlsAcceptor::new(identity)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::ConnectionRefused, err))?;
        Ok(Self {
            inner,
            tls_acceptor: Arc::new(TlsAcceptor::from(acceptor)),
            queue: Vec::new(),
        })
    }

    /// Loads a PKCS#12 identity from `path` (decrypted with `password`) and builds a listener
    /// around `inner` with it.
    ///
    /// # Errors
    /// Returns an `io::Error` of kind `InvalidInput` if the identity cannot be loaded, or any
    /// error produced by [`TlsListener::new`].
    pub fn new_pkcs<P: AsRef<Path>, T: AsRef<str>>(path: P, password: T, inner: TcpListener) -> std::io::Result<Self> {
        let identity = Self::load_tls_pkcs(path, password)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err.into_string()))?;
        Self::new(inner, identity)
    }

    /// Given a path and password, returns the asymmetric crypto identity.
    ///
    /// The file is read with blocking `std::fs::read`.
    ///
    /// # Errors
    /// Returns `NetworkError::Generic` if the file cannot be read or if
    /// `Identity::from_pkcs12` rejects its contents.
    pub fn load_tls_pkcs<P: AsRef<Path>, T: AsRef<str>>(path: P, password: T) -> Result<Identity, NetworkError> {
        let bytes = std::fs::read(path).map_err(|err| NetworkError::Generic(err.to_string()))?;
        Identity::from_pkcs12(&bytes, password.as_ref()).map_err(|err| NetworkError::Generic(err.to_string()))
    }

    /// Polls one pending handshake and, once it completes, pairs the TLS stream with the
    /// peer's address.
    ///
    /// Always yields `Some(..)` when ready. A handshake failure is mapped to an `io::Error` of
    /// kind `ConnectionRefused`. A failure to read the peer address (for example, because the
    /// peer has already disconnected) is returned as an error instead of panicking.
    ///
    /// The caller must not poll `future` again after this returns `Ready`.
    fn poll_future(future: &mut Pin<Box<dyn Future<Output=Result<TlsStream<TcpStream>, tokio_native_tls::native_tls::Error>>>>, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
        future.as_mut().poll(cx).map(|result| {
            Some(
                result
                    .map_err(|err| std::io::Error::new(std::io::ErrorKind::ConnectionRefused, err))
                    .and_then(|stream| {
                        // TlsStream -> native_tls stream -> AllowStd wrapper -> TcpStream
                        let peer_addr = stream.get_ref().get_ref().get_ref().peer_addr()?;
                        Ok((stream, peer_addr))
                    }),
            )
        })
    }
}

impl Stream for TlsListener {
    type Item = std::io::Result<(TlsStream<TcpStream>, SocketAddr)>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Self {
            inner,
            tls_acceptor,
            queue
        } = &mut *self;

        // on first poll, there won't be any futures loaded. We skip this immediately on first-poll
        for (idx, future) in queue.iter_mut().enumerate() {
            return match futures::ready!(Self::poll_future(future, cx)) {
                Some(Ok((stream, peer_addr))) => {
                    let _ = queue.remove(idx);
                    Poll::Ready(Some(Ok((stream, peer_addr))))
                }

                Some(Err(err)) => {
                    let _ = queue.remove(idx);
                    Poll::Ready(Some(Err(err)))
                }

                None => {
                    let _ = queue.remove(idx);
                    log::error!("TlsListener: Polled none");
                    Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "Polled none"))))
                }
            }
        }

        // Now, poll the TcpListener
        match futures::ready!(Pin::new(inner).poll_accept(cx)) {
            Ok((stream, _peer_addr)) => {
                // We received a TcpStream. Upgrade it to a TlsStream by storing a pollable future
                let tls_acceptor = tls_acceptor.clone();
                let future = async move {
                    tls_acceptor.accept(stream).await
                };
                let mut future = Box::pin(future) as Pin<Box<dyn Future<Output=Result<TlsStream<TcpStream>, tokio_native_tls::native_tls::Error>>>>;
                // poll the future once to register any internal wakers
                let poll_res = Self::poll_future(&mut future, cx);
                queue.push(future);
                poll_res
            }

            Err(err) => {
                log::error!("TLS Listener error: {:?}", err);
                Poll::Ready(None)
            }
        }
    }
}

AFAICT you never store more than one future in your queue. A future is only pushed after the `for` loop runs to completion, but when the Vec is non-empty the loop always exits on its first iteration: it either returns a result or, via `ready!`, returns `Poll::Pending`.

Assuming you do want to support accepting multiple streams concurrently, you shouldn't use `ready!`, as it propagates `Poll::Pending`. Instead, poll each future in the loop and manually skip any that return `Poll::Pending`.

1 Like

Great catch. Thank you!

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.