How to implement Tokio AsyncRead and AsyncWrite for a custom struct that wraps a WebSocket

Hi all,

I'm new to Rust but have experience in Go. I would like to try adding a WebSocket transport to https://github.com/rapiz1/rathole using fastwebsocket, but I'm currently stuck on how to convert a WebSocket into Tokio's AsyncRead and AsyncWrite. I don't know whether this is possible, and I'm confused about how to invoke the async read_frame function inside poll_read, which is a synchronous function. How should I do it?

Thanks

Why exactly do you want to implement AsyncRead and AsyncWrite for WebSocket? You can't do that anyway, because you are not allowed to implement a foreign trait on a foreign type. WebSocket is just a wrapper around a stream that already implements AsyncRead and AsyncWrite. Can you solve your problem by unwrapping the stream again with WebSocket::into_inner?
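The usual way around the orphan rule is a local newtype: wrap the foreign type in your own struct and implement the foreign trait on that. A minimal std-only sketch, where the hypothetical HexBytes wrapper stands in for a WebSocket wrapper and fmt::Display stands in for AsyncRead:

```rust
use std::fmt;

// `Vec<u8>` and `fmt::Display` both come from std, so
// `impl fmt::Display for Vec<u8>` is rejected by the orphan rule.
// A local newtype (hypothetical name) makes the same impl legal.
struct HexBytes(Vec<u8>);

impl fmt::Display for HexBytes {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Print each byte as two lowercase hex digits.
        for byte in &self.0 {
            write!(f, "{:02x}", byte)?;
        }
        Ok(())
    }
}
```

For example, HexBytes(vec![0xde, 0xad]).to_string() yields "dead". A local struct around the foreign WebSocket type is the same shape that allows implementing Tokio's traits on it.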

Hi, thanks for the reply

to implement this trait.

My idea is to create a struct that implements AsyncWrite and AsyncRead by wrapping write_frame and read_frame.

I've tried that; unfortunately, I got this error when using WebSocket::into_inner:

`(dyn hyper::upgrade::Io + Send + 'static)` cannot be shared between threads safely
the trait `Sync` is not implemented for `(dyn hyper::upgrade::Io + Send + 'static)`
required for `Unique<(dyn hyper::upgrade::Io + Send + 'static)>` to implement `Sync`

Also, if I use into_inner, the communication wouldn't adhere to the WebSocket protocol anymore, right?

It fundamentally does not make sense to implement AsyncRead/AsyncWrite for websockets. A websocket is a sequence of frames, but AsyncRead/AsyncWrite only makes sense for byte streams. These are not the same things.


Hi, thanks for the response.

It might not make sense from Rust's point of view, but is it possible to do?

I mean, in Go this can easily be done. There is even a library that natively provides a WebSocket API similar to the TCP API. Being able to tunnel TCP or UDP traffic over WebSocket is useful, especially in an environment where only incoming HTTPS connections are allowed from an untrusted network. This kind of tunneling is even used in Kubernetes (kubectl port-forward or kubevpn).

This has nothing to do with Rust. You have a stream of frames, but you want a stream of bytes. They're not the same thing. Doing this requires a conversion step of some kind; it's not just a simple wrapper.

I suppose you could reimplement the same conversion that Go's library performs.
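To make the conversion step concrete, here is a synchronous, std-only sketch. FrameReader is a hypothetical name, and an iterator of Vec<u8> stands in for the stream of WebSocket frames. Frame boundaries are dropped, and any bytes that don't fit in the caller's buffer must be kept for the next read; tokio-util's StreamReader does this same buffering for async streams of byte chunks.

```rust
use std::collections::VecDeque;
use std::io::{self, Read};

/// Hypothetical adapter that turns a sequence of frames into a byte stream.
/// Frame boundaries are dropped; bytes that don't fit in the caller's buffer
/// are kept in `leftover` and handed out by the next `read` call.
struct FrameReader<I> {
    frames: I,
    leftover: VecDeque<u8>,
}

impl<I: Iterator<Item = Vec<u8>>> FrameReader<I> {
    fn new(frames: I) -> Self {
        FrameReader { frames, leftover: VecDeque::new() }
    }
}

impl<I: Iterator<Item = Vec<u8>>> Read for FrameReader<I> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Pull frames until there is something to return; empty frames are skipped.
        while self.leftover.is_empty() {
            match self.frames.next() {
                Some(frame) => self.leftover.extend(frame),
                // No frames left: report end of stream.
                None => return Ok(0),
            }
        }
        // Copy as much as fits and keep the rest for a later call.
        let n = buf.len().min(self.leftover.len());
        for (dst, src) in buf.iter_mut().zip(self.leftover.drain(..n)) {
            *dst = src;
        }
        Ok(n)
    }
}
```

A reader with a 4-byte buffer that receives the frames "hel" and "lo world" gets "hel" first and "lo world" afterwards; it never sees where one frame ended and the next began.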

The fastwebsocket library you're using takes &mut self for both the read frame and write frame methods, which means that reading and writing cannot happen at the same time. On the other hand, the AsyncRead and AsyncWrite traits do allow for that. It's not going to be easy to wrap that library.

Perhaps you could use tokio-tungstenite instead. Since it's based on the Sink/Stream traits, it shouldn't have this problem.


Hmm, with tokio-tungstenite, implementing AsyncWrite the way Go converts them is relatively easy. However, AsyncRead will be annoying to do. You will need something similar to StreamReader, but you can't use it directly because you lose the AsyncWrite part.

This should be a good start. I did not try to compile it — it's just a sketch:


use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use futures::{Sink, Stream};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio_tungstenite::tungstenite::Message;
use tokio_util::io::StreamReader;

// TungsteniteWebSocket stands for tokio_tungstenite::WebSocketStream<S>.
// The Item type, convert_frame and convert_error are left open in this sketch.
struct WebsocketTunnel {
    inner: StreamReader<StreamWrapper>,
}

struct StreamWrapper {
    inner: TungsteniteWebSocket,
}

impl Stream for StreamWrapper {
    type Item;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.get_mut().inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(convert_frame(frame)))),
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(convert_error(err)))),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl AsyncRead for WebsocketTunnel {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>
    ) -> Poll<io::Result<()>> {
        // StreamReader does the frames-to-bytes buffering
        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
    }
}

impl AsyncBufRead for WebsocketTunnel {
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<io::Result<&[u8]>> {
        Pin::new(&mut self.get_mut().inner).poll_fill_buf(cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        Pin::new(&mut self.get_mut().inner).consume(amt)
    }
}

impl AsyncWrite for WebsocketTunnel {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        // StreamReader's fields are private, so reach the WebSocket via get_mut()
        let inner = &mut self.get_mut().inner.get_mut().inner;
        ready!(Pin::new(&mut *inner).poll_ready(cx).map_err(convert_error))?;
        // Each write is sent as one binary message
        match Pin::new(inner).start_send(Message::Binary(buf.to_vec())) {
            Ok(()) => Poll::Ready(Ok(buf.len())),
            Err(e) => Poll::Ready(Err(convert_error(e))),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.get_mut().inner.get_mut().inner)
            .poll_flush(cx)
            .map_err(convert_error)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.get_mut().inner.get_mut().inner)
            .poll_close(cx)
            .map_err(convert_error)
    }
}
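The write half is the easy one because each write maps to exactly one frame and no state is carried between calls, whereas the read half has to keep the unread tail of a frame. A std-only sketch of the write side, where the hypothetical FrameWriter records frames in a Vec instead of sending them over a socket:

```rust
use std::io::{self, Write};

/// Hypothetical write-side adapter: every non-empty `write` call becomes
/// exactly one binary frame. A `Vec` of frames stands in for a WebSocket sink.
struct FrameWriter {
    sent: Vec<Vec<u8>>,
}

impl Write for FrameWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Empty writes would only produce empty frames, so skip them.
        if !buf.is_empty() {
            self.sent.push(buf.to_vec());
        }
        // The whole buffer went into a single frame.
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        // Nothing is buffered here; a real sink would flush pending frames.
        Ok(())
    }
}
```

The frame boundaries chosen by the writer carry no meaning for the peer's reader, which only sees the concatenated bytes.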

Awesome. Thanks a lot.

This is a lot to digest for me, so let me try and explore it first.
Again, thanks for the direction.

Hi @alice, sorry if this sounds stupid, but what type should go in type Item;?

From the WebSocket stream I can get a Vec<u8>, but I'm struggling with how to convert it to a type that implements Buf.

I tried Result<Box<[u8]>, Error> and several other types for Item, but I always got this kind of error:

the method `poll_read` exists for struct `Pin<&mut StreamReader<StreamWrapper, Box<[u8]>>>`, but its trait bounds were not satisfied
the following trait bounds were not satisfied:
`Box<[u8]>: Buf`

You can wrap the box in a Cursor or use the bytes crate.
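Box<[u8]> on its own is rejected because Buf needs to track how many bytes have already been consumed, and a boxed slice has no position. Cursor adds that position; as far as I know, the bytes crate implements Buf for Cursor<T> when T: AsRef<[u8]>, and Bytes::from(vec) is the other easy route. Since Buf lives in the bytes crate, this std-only sketch uses the analogous std BufRead methods on a Cursor<Box<[u8]>> (fill_buf and consume play the roles of Buf's chunk and advance); take_prefix is a hypothetical helper.

```rust
use std::io::{BufRead, Cursor};

/// Hypothetical helper: takes up to `n` bytes off the front of the cursor.
/// `fill_buf` exposes the unread bytes and `consume` advances the position,
/// much like `chunk` and `advance` on `bytes::Buf`.
fn take_prefix(cursor: &mut Cursor<Box<[u8]>>, n: usize) -> Vec<u8> {
    let available = cursor.fill_buf().expect("in-memory reads cannot fail");
    let n = n.min(available.len());
    let prefix = available[..n].to_vec();
    cursor.consume(n);
    prefix
}
```

The cursor remembers that 5 bytes were taken, so a second call on "frame-payload" continues at "-payload".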

Hi @alice, thank you very much for the guidance. I think I've successfully wrapped the WebSocketStream in AsyncRead and AsyncWrite based on your code:

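use std::io::{Error, ErrorKind};
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use bytes::Bytes;
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
use tokio_util::io::StreamReader;

#[derive(Debug)]
struct WebsocketTunnel {
    inner: StreamReader<StreamWrapper, Bytes>,
}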

#[derive(Debug)]
struct StreamWrapper {
    inner: WebSocketStream<Upgraded>,
}

impl Stream for StreamWrapper {
    type Item = Result<Bytes, Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = &mut self.get_mut().inner;
        loop {
            return match ready!(Pin::new(&mut *inner).poll_next(cx)) {
                // A Close frame ends the byte stream, like end-of-stream
                None | Some(Ok(Message::Close(_))) => Poll::Ready(None),
                Some(Ok(Message::Binary(b))) => Poll::Ready(Some(Ok(Bytes::from(b)))),
                // Ping/Pong are control frames with no tunnel data; skip them
                Some(Ok(Message::Ping(_) | Message::Pong(_))) => continue,
                Some(Ok(_)) => {
                    Poll::Ready(Some(Err(Error::new(ErrorKind::Other, "unexpected frame"))))
                }
                Some(Err(err)) => Poll::Ready(Some(Err(Error::new(ErrorKind::Other, err)))),
            };
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl AsyncRead for WebsocketTunnel {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
    }
}

impl AsyncBufRead for WebsocketTunnel {
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
        Pin::new(&mut self.get_mut().inner).poll_fill_buf(cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        Pin::new(&mut self.get_mut().inner).consume(amt)
    }
}

impl AsyncWrite for WebsocketTunnel {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        // Borrow the WebSocket in place; moving it out of a `&mut` would not compile
        let inner = &mut self.get_mut().inner.get_mut().inner;
        ready!(Pin::new(&mut *inner)
            .poll_ready(cx)
            .map_err(|err| Error::new(ErrorKind::Other, err)))?;
        match Pin::new(inner).start_send(Message::Binary(buf.to_vec())) {
            Ok(()) => Poll::Ready(Ok(buf.len())),
            Err(e) => Poll::Ready(Err(Error::new(ErrorKind::Other, e))),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Pin::new(&mut self.get_mut().inner.get_mut().inner)
            .poll_flush(cx)
            .map_err(|err| Error::new(ErrorKind::Other, err))
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Pin::new(&mut self.get_mut().inner.get_mut().inner)
            .poll_close(cx)
            .map_err(|err| Error::new(ErrorKind::Other, err))
    }
}

But I'm currently facing this error when trying to pass the wrapper as the Stream type needed by the original program:

`(dyn upgrade::Io + Send + 'static)` cannot be shared between threads safely
the trait `Sync` is not implemented for `(dyn upgrade::Io + Send + 'static)`
required for `Unique<(dyn upgrade::Io + Send + 'static)>` to implement `Sync`

I've read that this requires unsafe code. Do you have any workaround for this? Or is it maybe possible to wrap hyper::upgrade::Upgraded so that it is Sync?

What is triggering that error? Usually, you should not need Sync for IO resources.


this line

Your Transport trait is a subtrait of Sync, so you can't implement it on types that are not Sync. Or rather, your Transport::Stream type must be Sync, which it isn't.

Thanks @jofas.
Is it possible to wrap a !Sync type so that it becomes Sync? Arc<Mutex<...>> maybe? Or is there a better alternative?

You can wrap Send types in a Mutex to make them Sync but are you sure you need the Sync bound in the first place?
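A std-only sketch of the Mutex approach, assuming a boxed dyn Read + Send as a stand-in for the dyn hyper::upgrade::Io + Send from the error message (both are Send but not Sync). BoxedIo, assert_sync and read_on_other_thread are hypothetical names. The cost is that every access now goes through the lock.

```rust
use std::io::Read;
use std::sync::{Arc, Mutex};
use std::thread;

/// Compiles only when `T` is `Sync`.
fn assert_sync<T: Sync>() {}

/// Hypothetical stand-in for hyper's upgraded connection: `Send`, but not `Sync`.
type BoxedIo = Box<dyn Read + Send>;

/// Reads all bytes from `io` on a second thread, sharing it through `Arc<Mutex<_>>`.
fn read_on_other_thread(io: BoxedIo) -> Vec<u8> {
    // assert_sync::<BoxedIo>(); // error: `dyn Read + Send` cannot be shared between threads safely

    // `Mutex<T>` is `Sync` whenever `T: Send`, so this check compiles.
    assert_sync::<Mutex<BoxedIo>>();

    let shared = Arc::new(Mutex::new(io));
    let worker = {
        let shared = Arc::clone(&shared);
        thread::spawn(move || {
            let mut out = Vec::new();
            // Every access goes through the lock.
            shared.lock().unwrap().read_to_end(&mut out).unwrap();
            out
        })
    };
    worker.join().unwrap()
}
```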

Thanks

Actually, I'm not the one who wrote the trait. I'm just trying to implement it with WebSocket as the underlying transport, since I'm trying to learn Rust.

Just use the sync_wrapper crate. You can wrap the StreamReader<StreamWrapper, Bytes> with the SyncWrapper utility.
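The idea behind sync_wrapper is that a wrapper which only ever hands out &mut T (never &T) can be Sync for any T, because a shared &Wrapper lets no thread touch the inner value. A minimal std-only imitation under the hypothetical name MySyncWrapper follows; in real code, use the crate's SyncWrapper rather than writing your own unsafe impl.

```rust
use std::cell::Cell;

/// Hypothetical, minimal imitation of the `sync_wrapper` crate's idea.
/// The wrapper never hands out `&T`, only `&mut T` or `T` by value.
pub struct MySyncWrapper<T> {
    inner: T,
}

// SAFETY: no method exposes `T` through `&self`, so a `&MySyncWrapper<T>`
// shared between threads cannot be used to reach `T` at all.
unsafe impl<T> Sync for MySyncWrapper<T> {}

impl<T> MySyncWrapper<T> {
    pub fn new(inner: T) -> Self {
        MySyncWrapper { inner }
    }

    /// `&mut self` already proves exclusive access, so no locking is needed.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    pub fn into_inner(self) -> T {
        self.inner
    }
}

/// Compiles only when `T` is `Sync`.
fn assert_sync<T: Sync>() {}

/// `Cell<u32>` is not `Sync`, yet the wrapper around it is.
fn wrapper_is_sync() {
    assert_sync::<MySyncWrapper<Cell<u32>>>();
}
```

Applied to the tunnel, the field would become SyncWrapper<StreamReader<StreamWrapper, Bytes>>, and each poll_* method would reach the reader through get_mut(), which is enough because those methods only need exclusive access.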


Works like a charm, awesome. Thanks a lot!