[Rust] WebSocket Server - Feedback requested

Hi,

I haven't been working with Rust for very long; previously I worked mostly with Python and C#.

I now need a performant WebSocket server for a project, and I thought Rust might be worth a try.

The WebSocket server will later serve as a kind of interface: incoming messages are to be processed further, for example to change, write, or read data in a database.
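
As a rough illustration of that processing step, the sketch below parses a text message into a command and applies it to an in-memory map that stands in for the database. The `Command` enum, the `parse_command` and `apply` functions, and the `GET key` / `SET key value` / `DEL key` text format are all assumptions made up for this example; none of them appear in the code further down.

```rust
use std::collections::HashMap;

/// Hypothetical commands a client could send as text frames.
#[derive(Debug, PartialEq)]
enum Command {
    Read { key: String },
    Write { key: String, value: String },
    Delete { key: String },
}

/// Parses "GET key", "SET key value" or "DEL key" (an assumed format).
fn parse_command(text: &str) -> Option<Command> {
    let mut parts = text.trim().splitn(3, ' ');
    let verb = parts.next()?;
    let key = parts.next()?.to_string();
    match (verb, parts.next()) {
        ("GET", None) => Some(Command::Read { key }),
        ("SET", Some(value)) => Some(Command::Write { key, value: value.to_string() }),
        ("DEL", None) => Some(Command::Delete { key }),
        _ => None,
    }
}

/// Applies a command to a HashMap that stands in for the real database
/// and returns the reply text for the client.
fn apply(db: &mut HashMap<String, String>, command: Command) -> String {
    match command {
        Command::Read { key } => db
            .get(&key)
            .cloned()
            .unwrap_or_else(|| "not found".to_string()),
        Command::Write { key, value } => {
            db.insert(key, value);
            "ok".to_string()
        }
        Command::Delete { key } => match db.remove(&key) {
            Some(_) => "ok".to_string(),
            None => "not found".to_string(),
        },
    }
}
```

In the real server the handler would call something like `parse_command` on each incoming text frame and talk to the actual database instead of a `HashMap`.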

In addition, states and state changes should be distributed to the clients. User management and permission handling are also planned.
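
Distributing state changes could look roughly like the sketch below: a registry holds one sender per client, and a broadcast pushes the new state to all of them. `Hub`, `subscribe`, and `broadcast` are hypothetical names, and std channels stand in for the tokio channels used in the code further down. Dropping clients whose receiver has gone away is an addition made for this example.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};

/// Registry of connected clients; each client is represented by the
/// sending half of its own channel (std channels stand in for tokio's).
struct Hub {
    clients: Vec<Sender<String>>,
}

impl Hub {
    fn new() -> Self {
        Hub { clients: Vec::new() }
    }

    /// Registers a new client and returns the receiving half it reads from.
    fn subscribe(&mut self) -> Receiver<String> {
        let (tx, rx) = channel();
        self.clients.push(tx);
        rx
    }

    /// Sends a state change to every client and drops clients whose
    /// receiver is gone. Returns how many clients are still registered.
    fn broadcast(&mut self, state: &str) -> usize {
        self.clients.retain(|tx| tx.send(state.to_string()).is_ok());
        self.clients.len()
    }
}
```

A client that disconnects simply drops its `Receiver`; the next `broadcast` notices the failed send and removes the stale entry.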

Before I continue working on my code, I would like to hear other opinions on whether my concept and the chosen architecture make sense for this project.

My concept looks like this:

Individual components such as message_handler and websocket_server each run in their own thread and process their work asynchronously inside that thread.
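
Reduced to its wiring, the concept can be sketched as follows. This is a simplified stand-in, not the actual implementation: plain OS threads and bounded std channels take the place of the per-thread async runtimes and tokio channels, the uppercase conversion is only a placeholder for real processing, and `run_pipeline` is a hypothetical helper.

```rust
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::thread;

/// Stand-in for message_handler: receives text, processes it, forwards it.
fn message_handler(rx: Receiver<String>, tx: SyncSender<String>) {
    for text in rx {
        // Placeholder processing step; the real handler would do its work here.
        if tx.send(text.to_uppercase()).is_err() {
            break; // the downstream component has shut down
        }
    }
}

/// Wires up the pipeline: producer -> message_handler -> collector,
/// each stage on its own thread, connected by bounded channels.
fn run_pipeline(inputs: Vec<String>) -> Vec<String> {
    let (message_tx, message_rx) = sync_channel::<String>(32);
    let (broadcast_tx, broadcast_rx) = sync_channel::<String>(32);

    let handler = thread::spawn(move || message_handler(message_rx, broadcast_tx));
    let collector =
        thread::spawn(move || broadcast_rx.into_iter().collect::<Vec<String>>());

    for text in inputs {
        message_tx.send(text).expect("handler thread stopped early");
    }
    drop(message_tx); // closing the channel lets both threads finish

    handler.join().expect("handler thread panicked");
    collector.join().expect("collector thread panicked")
}
```

Each stage only knows the channel ends it was given, which is the same decoupling the components in the full code rely on.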

I think this concept combines the advantages of parallel and asynchronous processing. However, I'm still unsure about one point: to what extent could the channels become a bottleneck and affect performance?

What do you think: are there better or more efficient ways of implementing my plan? Many thanks in advance, I appreciate any constructive advice.
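
One way to make the bottleneck question concrete: a bounded channel only queues as many messages as its capacity, so once a consumer falls behind, senders have to wait. The sketch below shows that limit with a std `sync_channel`, used here as a stand-in for the bounded tokio channel; `accepted_before_full` is a hypothetical helper, and the receiver is deliberately never read to simulate a stalled consumer.

```rust
use std::sync::mpsc::{sync_channel, TrySendError};

/// Fills a bounded channel without ever reading from it and returns how
/// many messages were accepted before the channel reported "full".
fn accepted_before_full(capacity: usize, attempts: usize) -> usize {
    // _rx is kept alive so the channel stays connected but is never drained.
    let (tx, _rx) = sync_channel::<u32>(capacity);
    let mut accepted = 0;
    for i in 0..attempts {
        match tx.try_send(i as u32) {
            Ok(()) => accepted += 1,
            // A blocking send would wait here until the consumer catches up.
            Err(TrySendError::Full(_)) => break,
            Err(TrySendError::Disconnected(_)) => break,
        }
    }
    accepted
}
```

With a capacity of 32, the 33rd message is the first one that cannot be queued; whether that matters in practice depends on how fast the receiving side drains the channel.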

use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::accept_async;
use futures::stream::StreamExt;
use futures::sink::SinkExt;

#[derive(Debug)]
struct SocketMessage {
    text: String,
}

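// Sending half of one client's outgoing channel.
// Note: mpsc::Sender::send takes &self, so the Mutex is not strictly needed.
struct WebSocketTx {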
    sender: Mutex<mpsc::Sender<Message>>,
}

impl WebSocketTx {
    async fn send(&self, message: Message) -> Result<(), mpsc::error::SendError<Message>> {
        let sender = self.sender.lock().await;
        sender.send(message).await
    }
}


// Forwards every message received on broadcaster_rx to all registered clients.
async fn broadcaster(
    mut broadcaster_rx: mpsc::Receiver<SocketMessage>,
    clients: Arc<Mutex<Vec<WebSocketTx>>>,
) {
    while let Some(socket_message) = broadcaster_rx.recv().await {
        let clients = clients.lock().await;
        for client in clients.iter() {

            if let Err(e) = client.send(Message::Text(socket_message.text.clone())).await {
                eprintln!("Error sending the message to the client: {:?}", e);
            }
        }
    }
}


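// Performs the WebSocket handshake, registers the client, and forwards
// its incoming text messages to the message handler.
async fn handle_connection(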
    raw_stream: TcpStream,
    message_tx: mpsc::Sender<SocketMessage>,
    clients_tx: Arc<Mutex<Vec<WebSocketTx>>>,
) -> Result<(), Box<dyn std::error::Error>> {
    let ws_stream = accept_async(raw_stream).await.expect("Error during the WebSocket handshake");
    let (mut write, mut read) = ws_stream.split();

    let (client_tx, client_rx) = mpsc::channel(32);
    clients_tx.lock().await.push(WebSocketTx {
        sender: Mutex::new(client_tx),
    });


    let client_rx_stream = tokio_stream::wrappers::ReceiverStream::new(client_rx);
    tokio::spawn(async move {
        write.send_all(&mut client_rx_stream.map(Ok)).await
            .unwrap_or_else(|e| eprintln!("Error sending the message: {}", e));
    });


    while let Some(message_result) = read.next().await {
        match message_result {
            Ok(msg) => {
                if let Message::Text(text) = msg {

                    message_tx.send(SocketMessage { text }).await?;
                }
            }
            Err(e) => {
                eprintln!("Error in the WebSocket connection: {:?}", e);
                break;
            }
        }
    }

    Ok(())
}


// Accepts TCP connections and spawns one task per client.
async fn server_run(
    message_tx: mpsc::Sender<SocketMessage>,
    clients_tx: Arc<Mutex<Vec<WebSocketTx>>>,
) -> Result<(), std::io::Error> {
    let listener = TcpListener::bind("0.0.0.0:8002").await?;
    println!("Server listening on: 0.0.0.0:8002");

    while let Ok((stream, _)) = listener.accept().await {
        let message_tx_clone = message_tx.clone();
        let clients_tx_clone = clients_tx.clone();


        tokio::spawn(async move {
            if let Err(e) = handle_connection(stream, message_tx_clone, clients_tx_clone).await {
                eprintln!("Error handling the connection: {:?}", e);
            }
        });
    }

    Ok(())
}


// Receives client messages and passes them on to the broadcaster.
async fn message_handler(
    mut message_rx: mpsc::Receiver<SocketMessage>,
    broadcaster_tx: mpsc::Sender<SocketMessage>,
) {
    while let Some(socket_message) = message_rx.recv().await {
        println!("Message received: {}", socket_message.text);

        if broadcaster_tx.send(socket_message).await.is_err() {
            eprintln!("Error sending the message to the broadcaster.");
        }
    }
}




#[tokio::main]
async fn main() {

    let (message_tx, message_rx) = mpsc::channel::<SocketMessage>(32);
    // Not used yet; the underscore prefix silences the unused-variable warning.
    let (_response_tx, _response_rx) = mpsc::channel::<SocketMessage>(32);
    let (broadcaster_tx, broadcaster_rx) = mpsc::channel::<SocketMessage>(32);



    let clients_tx = Arc::new(Mutex::new(Vec::new()));



    // Each component runs on a blocking thread with its own Tokio runtime.
    let message_handler_thread = tokio::task::spawn_blocking(move || {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(message_handler(message_rx, broadcaster_tx))
    });

    let broadcaster_rx_clone = broadcaster_rx;
    let broadcaster_clients_tx_clone = clients_tx.clone();


    let broadcaster_thread = tokio::task::spawn_blocking(move || {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(broadcaster(broadcaster_rx_clone, broadcaster_clients_tx_clone))
    });

    let server_clients_tx_clone = clients_tx.clone();

    let server_thread = tokio::task::spawn_blocking(move || {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(server_run(message_tx, server_clients_tx_clone))
    });




    if let Err(e) = message_handler_thread.await {
        eprintln!("Message handler thread has encountered an error: {:?}", e);
    } else {
        println!("Message handler thread finished successfully.");
    }


    if let Err(e) = broadcaster_thread.await {
        eprintln!("Broadcaster thread has encountered an error: {:?}", e);
    } else {
        println!("Broadcaster thread finished successfully.");
    }


    if let Err(e) = server_thread.await {
        eprintln!("Server thread has encountered an error: {:?}", e);
    } else {
        println!("Server thread finished successfully.");
    }

}
