Async closure borrowing issues - configurable socket connection handler

Hello everyone!

This is my first time asking for help here, and I am also a non-native speaker, so please bear with me.

What I want to do seems simple, yet is so difficult for me in Rust.

I have two websockets, and when I get a message from one websocket, I want to pass it on to the other.

use async_recursion::async_recursion;
use futures_util::{
    SinkExt,
    StreamExt,
    stream::{SplitSink, SplitStream}
};
use std::error::Error;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::task;
use tokio_tungstenite::{
    connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream,
};
use tracing::info;

pub type WebSocketWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
pub type WebSocketRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

const WEBSOCKET_PATH: &str = "ws://localhost:6666/ws";
const OTHER_WEBSOCKET_PATH: &str = "ws://remote.url:8080/ws";

struct WebSocketHandler {
    connection_attempts: u8,
    max_connection_attempts: u8,
    pub read: Option<WebSocketRead>,
    pub write: Option<Arc<Mutex<WebSocketWrite>>>,
}

impl WebSocketHandler {
    async fn new(max_connection_attempts: u8) -> Self {
        Self {
            connection_attempts: 0,
            max_connection_attempts,
            read: None,
            write: None,
        }
    }

    #[async_recursion]
    async fn connect(mut self, websocket_path: &str) -> Result<Self, String> {
        self.connection_attempts += 1;
        info!("Connection attempt: {}", self.connection_attempts);
        let attempt = connect_async(websocket_path).await;

        return match attempt {
            Ok(connection) => {
                let (stream, response) = connection;
                info!("Response: {:?}", response);
                let (write, read) = stream.split();
                let write = Arc::new(Mutex::new(write));
                self.read = Some(read);
                self.write = Some(write);
                Ok(self)
            }
            Err(connect_error) => {
                if self.connection_attempts >= self.max_connection_attempts {
                    return Err(format!(
                        "Failed to connect to {}: {:?}", websocket_path,
                        connect_error
                    ));
                }
                return self.connect(websocket_path).await;
            }
        };
    }

    async fn handle_reads<Fut>(&mut self, mut f: impl FnMut(Message) -> Fut) -> Result<(), &str>
    where
        Fut: std::future::Future<Output = ()>,
    {
        if let None = self.read {
            return Err("Websocket not initialized correctly!");
        }
        let mut read = self.read.take().unwrap();
        while let Some(msg) = read.next().await {
            match msg {
                // Futures are lazy: the handler's future must be awaited, or it never runs.
                Ok(msg) => f(msg).await,
                Err(err) => todo!("Unhandled error {:?}", err),
            };
        }
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    pretty_env_logger::init_timed();
    let ws = WebSocketHandler::new(5).await;
    let other_ws = WebSocketHandler::new(5).await;
    let other_ws = other_ws
        .connect(OTHER_WEBSOCKET_PATH)
        .await
        .expect("Other websocket should just work");

    let o_write = other_ws.write.clone().expect("Other websocket to have a write end");

    let connect = task::spawn(async move {
        match ws.connect(WEBSOCKET_PATH).await {
            Ok(mut connection) => {
                let read_handler = connection.handle_reads(|msg| async move {
                    println!("Got a message, awesome! {}", msg);
                    let ret = o_write.lock().await.send(Message::Text(format!("Got a message: {:?}", msg))).await.unwrap();
                    ret
                });

                let _ = read_handler.await;

            }
            Err(err) => {
                panic!("Failed to connect to {}! {:?}", WEBSOCKET_PATH, err);
            }
        }
    });

    let _ = connect.await;

    Ok(())
}

For this, I get the following error:

error[E0507]: cannot move out of `o_write`, a captured variable in an `FnMut` closure
   --> crates/core/src/main.rs:174:66
    |
169 |       let o_write = other_ws.write.clone().expect("Other websocket to have a write end");
    |           ------- captured outer variable
...
174 |                   let read_handler = connection.handle_reads(|msg| async move {
    |  ____________________________________________________________-----_^
    | |                                                            |
    | |                                                            captured by this `FnMut` closure
175 | |                     println!("Got a message, awesome! {}", msg);
176 | |                     let ret = o_write.lock().await.send(Message::Text(format!("Got a message: {:?}", msg))).await.unwrap();
    | |                               -------
    | |                               |
    | |                               variable moved due to use in generator
    | |                               move occurs because `o_write` has type `Arc<tokio::sync::Mutex<SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>>>`, which does not implement the `Copy` trait
177 | |                     ret
178 | |                 });
    | |_________________^ move out of `o_write` occurs here

I can't seem to wrap my head around what I should do differently.

In the end, I want to inspect the message and, depending on what it is, forward it or not.
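Deciding per message whether to forward it can be kept separate from the socket plumbing. The sketch below is only an illustration: the Message enum is a hypothetical, simplified stand-in for tungstenite's Message (the real type has more variants, such as Pong and Frame), and the forwarding rules are made up.

```rust
// Hypothetical, simplified stand-in for tungstenite's `Message` type.
#[derive(Debug, Clone, PartialEq)]
enum Message {
    Text(String),
    Binary(Vec<u8>),
    Ping(Vec<u8>),
    Close,
}

// Decide whether (and in what form) a message should be forwarded.
// Example rules: forward non-blank text with a prefix, pass binary through,
// drop everything else (pings, close frames).
fn filter_message(msg: Message) -> Option<Message> {
    match msg {
        Message::Text(text) if !text.trim().is_empty() => {
            Some(Message::Text(format!("Got a message: {}", text)))
        }
        Message::Binary(bytes) => Some(Message::Binary(bytes)),
        _ => None,
    }
}

fn main() {
    let incoming = vec![
        Message::Text("hello".to_string()),
        Message::Text("   ".to_string()),
        Message::Ping(vec![1, 2, 3]),
        Message::Binary(vec![0xff]),
        Message::Close,
    ];
    for msg in incoming {
        // In the real handler, `Some(out)` would be sent on the other socket.
        match filter_message(msg.clone()) {
            Some(out) => println!("forward {:?}", out),
            None => println!("drop {:?}", msg),
        }
    }
}
```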

I also want to handle failing websocket connections by trying to establish them again.
(I saw there was a crate for that, but I figured I should try this as an exercise myself.)
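As a sketch of that exercise (assuming nothing beyond the standard library), the recursion in connect can be replaced by a plain loop, which also removes the need for #[async_recursion]. The retry helper and the fake connect closure are hypothetical; block_on is a tiny busy-polling executor so the example runs without tokio. A real client would probably also wait between attempts, e.g. with tokio::time::sleep.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Tiny executor for the sketch: a waker that does nothing plus a polling loop.
// Good enough here because the futures below never wait on real I/O.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

// Hypothetical helper: run `op` until it succeeds or `max_attempts` is reached
// (it always runs at least once). Returns the first success or the last error.
async fn retry<T, E, Fut>(max_attempts: u8, mut op: impl FnMut(u8) -> Fut) -> Result<T, E>
where
    Fut: Future<Output = Result<T, E>>,
{
    let mut attempt: u8 = 0;
    loop {
        attempt += 1;
        match op(attempt).await {
            Ok(value) => return Ok(value),
            Err(err) if attempt >= max_attempts => return Err(err),
            // A real client would likely back off here before retrying.
            Err(_) => continue,
        }
    }
}

fn main() {
    // Hypothetical stand-in for `connect_async` that fails twice, then succeeds.
    let result: Result<String, String> = block_on(retry(5, |attempt| async move {
        println!("Connection attempt: {}", attempt);
        if attempt < 3 {
            Err(format!("attempt {} failed", attempt))
        } else {
            Ok(format!("connected on attempt {}", attempt))
        }
    }));
    println!("{:?}", result);
}
```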

Try something like

    let connect = task::spawn(async move {
        match ws.connect(WEBSOCKET_PATH).await {
            Ok(mut connection) => {
+               let o_write = &o_write;
                let read_handler = connection.handle_reads(|msg| async move {
                    println!("Got a message, awesome! {}", msg);
                    let ret = o_write.lock().await.send(Message::Text(format!("Got a message: {:?}", msg))).await.unwrap();
                    ret
                });

                let _ = read_handler.await;

            }
            Err(err) => {
                panic!("Failed to connect to {}! {:?}", WEBSOCKET_PATH, err);
            }
        }
    });
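Here is a std-only sketch of the same idea that can be compiled and run in isolation. Everything in it is a hypothetical stand-in: a Vec<String> behind a std Mutex plays the role of the websocket write half, handle_reads is a simplified version of the method above, and block_on is a tiny busy-polling executor so that tokio is not needed.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

// Tiny executor for the sketch: a waker that does nothing plus a polling loop.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

// Simplified stand-in for `handle_reads`: pass each message to `f` and await
// the future it returns before handling the next one.
async fn handle_reads<Fut>(messages: Vec<String>, mut f: impl FnMut(String) -> Fut)
where
    Fut: Future<Output = ()>,
{
    for msg in messages {
        f(msg).await;
    }
}

fn forward_all(messages: Vec<String>) -> Vec<String> {
    // Hypothetical stand-in for `o_write: Arc<Mutex<WebSocketWrite>>`.
    let o_write: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
    block_on(async {
        // Shadow the Arc with a reference to it. `async move` now moves a copy
        // of this reference into every future instead of the Arc itself, so
        // the FnMut closure can be called any number of times.
        let o_write = &o_write;
        handle_reads(messages, |msg| async move {
            println!("Got a message, awesome! {}", msg);
            o_write.lock().unwrap().push(format!("Got a message: {}", msg));
        })
        .await;
    });
    // Bind to a local so the lock guard is dropped before `o_write` goes out of scope.
    let sent = o_write.lock().unwrap().clone();
    sent
}

fn main() {
    let sent = forward_all(vec!["first".to_string(), "second".to_string()]);
    println!("{:?}", sent);
}
```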

So the problem is that the |msg| async move { … } closure and the async block inside it capture all variables by value, i.e. by moving them, because of the move keyword on the async block. The closure is an FnMut that handle_reads may call many times, but o_write can only be moved out of it once, which is exactly what the compiler complains about. Without the move keyword, however, msg is only used by reference inside the async block (println! and format! just borrow it), so the async block would borrow msg, and the resulting future could not be returned from the closure, because it would borrow the closure's own argument msg.

Ideally, Rust would have more first-class support for some sort of “async closures” that avoid this problem by moving all closure arguments into the future by value, while capturing variables from outside only by reference when by-value is not necessary. For now, there are two alternatives:

  • removing the move, but accessing msg in a way that forces a move of msg nonetheless
  • making sure that, for o_write (and in general for captured variables you don’t want to move unnecessarily), the thing being moved is itself only a reference.

The second idea is what I suggested above, whereas the first idea could look like this:

    let connect = task::spawn(async move {
        match ws.connect(WEBSOCKET_PATH).await {
            Ok(mut connection) => {
-               let read_handler = connection.handle_reads(|msg| async move {
+               let read_handler = connection.handle_reads(|msg| async {
+                   let msg = msg;
                    println!("Got a message, awesome! {}", msg);
                    let ret = o_write.lock().await.send(Message::Text(format!("Got a message: {:?}", msg))).await.unwrap();
                    ret
                });

                let _ = read_handler.await;

            }
            Err(err) => {
                panic!("Failed to connect to {}! {:?}", WEBSOCKET_PATH, err);
            }
        }
    });
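The same kind of std-only sketch for this variant, again with hypothetical stand-ins (a Vec<String> behind a std Mutex instead of the write half, a simplified handle_reads, and a tiny busy-polling block_on executor). Here o_write stays an Arc and the futures only borrow it; the let msg = msg; line is what moves the message into each future.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

// Tiny executor for the sketch: a waker that does nothing plus a polling loop.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

// Simplified stand-in for `handle_reads`.
async fn handle_reads<Fut>(messages: Vec<String>, mut f: impl FnMut(String) -> Fut)
where
    Fut: Future<Output = ()>,
{
    for msg in messages {
        f(msg).await;
    }
}

fn forward_all(messages: Vec<String>) -> Vec<String> {
    // Hypothetical stand-in for `o_write: Arc<Mutex<WebSocketWrite>>`.
    let o_write: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
    block_on(async {
        handle_reads(messages, |msg| async {
            // Without `move`, the block would only borrow `msg` (println! and
            // format! take references). This line moves the String into the future.
            let msg = msg;
            println!("Got a message, awesome! {}", msg);
            // `o_write` is only borrowed, so nothing is moved out of the FnMut closure.
            o_write.lock().unwrap().push(format!("Got a message: {}", msg));
        })
        .await;
    });
    // Bind to a local so the lock guard is dropped before `o_write` goes out of scope.
    let sent = o_write.lock().unwrap().clone();
    sent
}

fn main() {
    println!("{:?}", forward_all(vec!["first".to_string(), "second".to_string()]));
}
```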

Note that I didn’t spend the time testing the code and the solutions I proposed; feel free to ask follow-up questions if problems remain or my suggestions don’t work at all.

Omg... this makes sense and helps a lot. Thanks a bunch!

Are there any upsides or downsides to consider between the two approaches?

Not in terms of what they do at run time, no; they should be equivalent. Choose whichever feels more ergonomic to you personally.


Well, maybe one thing: for other argument types, the let msg = msg; trick might not work at all. For example, if the type implements Copy, it would still be captured by reference even with a let msg = msg;. So the “keep it async move, but capture references by value” approach is more generally applicable, i.e. it will always work in such situations; if your goal is consistency of sorts, that could be an argument in favor of that approach.

On the other hand, in situations where you have few (e.g. a single) non-Copy arguments, but a lot of captured variables, the approach without the move can be a lot lighter than creating references for all captured variables manually.
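To illustrate the Copy caveat, here is a sketch where the closure argument is a u32. Writing async { let id = id; … } would still capture id by reference and be rejected, so this version keeps async move and only moves references to the captured variables. The for_each_async helper, the ids and the "id" prefix are hypothetical; block_on is a tiny busy-polling executor so no runtime is needed.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

// Tiny executor for the sketch: a waker that does nothing plus a polling loop.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

// Hypothetical helper: call `f` for each item and await the returned future.
async fn for_each_async<T, Fut>(items: Vec<T>, mut f: impl FnMut(T) -> Fut)
where
    Fut: Future<Output = ()>,
{
    for item in items {
        f(item).await;
    }
}

fn label_all(ids: Vec<u32>) -> Vec<String> {
    let prefix = String::from("id");
    let out: Mutex<Vec<String>> = Mutex::new(Vec::new());
    block_on(async {
        // Shadow the captured variables with references; `async move` copies
        // these references (not the String or the Mutex) into each future.
        let (prefix, out) = (&prefix, &out);
        for_each_async(ids, |id: u32| async move {
            // `id` is Copy and is copied into the future by `async move`,
            // so the future borrows nothing from the closure itself.
            out.lock().unwrap().push(format!("{}-{}", prefix, id));
        })
        .await;
    });
    out.into_inner().unwrap()
}

fn main() {
    println!("{:?}", label_all(vec![1, 2, 3]));
}
```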

