Axum: within the standard chat example, how would you implement multiple chat rooms?

I've been following the chat example to get a better understanding of axum and WebSockets, but I'm having trouble conceptualizing how one would create multiple chat rooms. The only approach I've thought of is to add a field to the AppState struct holding something like a Vec<u32> of chat room IDs. When a new chat room is made, I'd create a new room ID, attach it to each outgoing message in the send_task section (inside a stringified JSON object or something), and then add a check in the recv_task section that only broadcasts messages to users with the same active room ID.

I feel like I might be missing something, though. In the main function of the chat example, a single Sender and Receiver are made with the broadcast::channel(100) function (from tokio::sync). Essentially, what I'm looking to do is create a new Sender/Receiver pair, via a new broadcast::channel(100) call, whenever a new chat room is created, but I'm a little lost on how to implement that. Any advice? Thank you!

You could probably make a "room ID" based system, but it would have one major drawback: every message in every room would wake every single task that listens for new messages, just so each task could check whether the message's room ID matched its own.
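To make that cost concrete, here is a minimal, dependency-free sketch of the single-channel design. It uses std::sync::mpsc as a stand-in for tokio's broadcast channel (an assumption so it runs without tokio), and the Hub, Tagged, and drain names are hypothetical. Every subscriber receives every message and has to throw away the ones tagged for other rooms.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};

/// A message tagged with the room it belongs to (hypothetical format).
#[derive(Clone, Debug)]
struct Tagged {
    room: String,
    text: String,
}

/// One shared fan-out for the whole server, like the single broadcast channel in main().
struct Hub {
    subscribers: Vec<Sender<Tagged>>,
}

impl Hub {
    fn new() -> Self {
        Hub { subscribers: Vec::new() }
    }

    /// Every connection subscribes to the same hub, regardless of its room.
    fn subscribe(&mut self) -> Receiver<Tagged> {
        let (tx, rx) = channel();
        self.subscribers.push(tx);
        rx
    }

    /// Every message is delivered to every subscriber.
    fn send(&self, room: &str, text: &str) {
        for sub in &self.subscribers {
            let _ = sub.send(Tagged { room: room.to_string(), text: text.to_string() });
        }
    }
}

/// Drains a subscriber's queue and returns (messages delivered, texts kept for `my_room`).
fn drain(rx: &Receiver<Tagged>, my_room: &str) -> (usize, Vec<String>) {
    let mut delivered = 0;
    let mut kept = Vec::new();
    while let Ok(msg) = rx.try_recv() {
        // Each delivery corresponds to one wake-up of this connection's task.
        delivered += 1;
        if msg.room == my_room {
            kept.push(msg.text);
        }
    }
    (delivered, kept)
}
```

For example, a subscriber in room "b" that drains after two messages to "a" and one to "b" has handled three deliveries but keeps only one.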

Another strategy would be to observe that the code in the example is essentially already defining a single chat room. If we can gather up all of the state associated with that into a "room" data structure, then it should be pretty straightforward to add multiple chat rooms.
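A minimal sketch of what that "room" data structure could look like, assuming rooms are keyed by name and usernames only need to be unique within a room. Room, Rooms, and join are hypothetical names, and the broadcast channel is left out so the sketch focuses on the shape of the state and runs with std only.

```rust
use std::collections::{HashMap, HashSet};

/// State the single-room example kept globally, now stored per room.
#[derive(Default)]
struct Room {
    user_set: HashSet<String>,
}

/// All rooms, keyed by room name.
#[derive(Default)]
struct Rooms {
    rooms: HashMap<String, Room>,
}

impl Rooms {
    /// Gets or creates the room, then tries to claim `username` inside it.
    /// Returns false if the name is already taken in that room.
    fn join(&mut self, room_name: &str, username: &str) -> bool {
        let room = self.rooms.entry(room_name.to_string()).or_default();
        // HashSet::insert returns false when the value was already present.
        room.user_set.insert(username.to_string())
    }

    fn room_count(&self) -> usize {
        self.rooms.len()
    }
}
```

With this layout, "alice" can join both "rust" and "axum", but a second "alice" in "rust" is rejected.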

I'd suggest trying that yourself first as an exercise, but if you get stuck, my version is included below.

As a hint: you can just create a new broadcast::channel(100) for each newly created room instead of creating a single one in main().
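To illustrate why a channel per room avoids waking everyone, here is a dependency-free sketch. It keeps a per-room list of std::sync::mpsc::Sender values as a stand-in for a per-room broadcast::Sender (an assumption so it runs without tokio); ChatRooms, RoomChannel, subscribe, and send are hypothetical names. A message sent to one room is only delivered to that room's subscribers.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};

/// Stand-in for broadcast::Sender<String>: one fan-out list per room.
#[derive(Default)]
struct RoomChannel {
    subscribers: Vec<Sender<String>>,
}

#[derive(Default)]
struct ChatRooms {
    rooms: HashMap<String, RoomChannel>,
}

impl ChatRooms {
    /// Creates the room's channel on first use, similar to or_insert_with(RoomState::new).
    fn subscribe(&mut self, room_name: &str) -> Receiver<String> {
        let (tx, rx) = channel();
        self.rooms
            .entry(room_name.to_string())
            .or_default()
            .subscribers
            .push(tx);
        rx
    }

    /// Delivers only to subscribers of `room_name`; other rooms are never touched.
    fn send(&self, room_name: &str, msg: &str) {
        if let Some(room) = self.rooms.get(room_name) {
            for sub in &room.subscribers {
                let _ = sub.send(msg.to_string());
            }
        }
    }
}
```

In the real server, the per-room broadcast::Sender plays the role of RoomChannel, and each connection calls subscribe() on its own room's sender only.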

Solution

main.rs

//! Example chat application.
//!
//! Run with
//!
//! ```not_rust
//! cd examples && cargo run -p example-chat
//! ```

use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        State,
    },
    response::{Html, IntoResponse},
    routing::get,
    Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use serde::Deserialize;
use std::{
    collections::{HashMap, HashSet},
    net::SocketAddr,
    sync::{Arc, Mutex},
};
use tokio::sync::broadcast;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// Our shared state
struct AppState {
    /// Keys are the name of the channel
    rooms: Mutex<HashMap<String, RoomState>>,
}

struct RoomState {
    /// Previously stored in AppState
    user_set: HashSet<String>,
    /// Previously created in main.
    tx: broadcast::Sender<String>,
}

impl RoomState {
    fn new() -> Self {
        Self {
            // Track usernames per room rather than globally.
            user_set: HashSet::new(),
            // Create a new channel for every room
            tx: broadcast::channel(100).0,
        }
    }
}

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new(
            std::env::var("RUST_LOG").unwrap_or_else(|_| "example_chat=trace".into()),
        ))
        .with(tracing_subscriber::fmt::layer())
        .init();

    let app_state = Arc::new(AppState {
        rooms: Mutex::new(HashMap::new()),
    });

    // In axum 0.6, state is attached with the `with_state` method after the routes.
    let app = Router::new()
        .route("/", get(index))
        .route("/websocket", get(websocket_handler))
        .with_state(app_state);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

async fn websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
    ws.on_upgrade(|socket| websocket(socket, state))
}

async fn websocket(stream: WebSocket, state: Arc<AppState>) {
    // By splitting we can send and receive at the same time.
    let (mut sender, mut receiver) = stream.split();

    // These are filled in by the connect loop below once the client sends a valid
    // connect message; they live outside the loop because we need them afterwards.
    let mut tx = None::<broadcast::Sender<String>>;
    let mut username = String::new();
    let mut channel = String::new();

    // Loop until a text message is found.
    while let Some(Ok(message)) = receiver.next().await {
        if let Message::Text(name) = message {
            #[derive(Deserialize)]
            struct Connect {
                username: String,
                channel: String,
            }

            let connect: Connect = match serde_json::from_str(&name) {
                Ok(connect) => connect,
                Err(error) => {
                    tracing::error!(%error);
                    let _ = sender
                        .send(Message::Text(String::from(
                            "Failed to parse connect message",
                        )))
                        .await;
                    // Without a valid connect message there is no room to join.
                    return;
                }
            };

            // Scope to drop the mutex guard before the next await.
            {
                let mut rooms = state.rooms.lock().unwrap();

                channel = connect.channel.clone();
                // Get the room, creating it (and its broadcast channel) if it doesn't exist yet.
                let room = rooms.entry(connect.channel).or_insert_with(RoomState::new);

                tx = Some(room.tx.clone());

                // Claim the username if it is non-empty and not taken in this room.
                // `HashSet::insert` returns false if the name was already present.
                if !connect.username.is_empty() && room.user_set.insert(connect.username.clone()) {
                    username = connect.username;
                }
            }

            // If a username was claimed, leave the loop; otherwise reject the client and quit.
            if !username.is_empty() {
                break;
            } else {
                // Only tell this client that the username is taken.
                let _ = sender
                    .send(Message::Text(String::from("Username already taken.")))
                    .await;

                return;
            }
        }
    }

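    // The loop can also end without a valid connect message (e.g. the client disconnected),
    // in which case `tx` is still `None` and there is nothing left to do.
    let tx = match tx {
        Some(tx) => tx,
        None => return,
    };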
    // Subscribe before sending joined message.
    let mut rx = tx.subscribe();

    // Send joined message to all subscribers.
    let msg = format!("{} joined.", username);
    tracing::debug!("{}", msg);
    let _ = tx.send(msg);

    // This task will receive broadcast messages and send text message to our client.
    let mut send_task = tokio::spawn(async move {
        while let Ok(msg) = rx.recv().await {
            // In any websocket error, break loop.
            if sender.send(Message::Text(msg)).await.is_err() {
                break;
            }
        }
    });

    // We need to access the `tx` variable directly again, so we can't shadow it here.
    // I moved the task spawning into a new block so the original `tx` is still visible later.
    let mut recv_task = {
        // Clone things we want to pass to the receiving task.
        let tx = tx.clone();
        let name = username.clone();

        // This task will receive messages from client and send them to broadcast subscribers.
        tokio::spawn(async move {
            while let Some(Ok(Message::Text(text))) = receiver.next().await {
                // Add username before message.
                let _ = tx.send(format!("{}: {}", name, text));
            }
        })
    };

    // If any one of the tasks exit, abort the other.
    tokio::select! {
        _ = (&mut send_task) => recv_task.abort(),
        _ = (&mut recv_task) => send_task.abort(),
    };

    // Send user left message.
    let msg = format!("{} left.", username);
    tracing::debug!("{}", msg);
    let _ = tx.send(msg);
    let mut rooms = state.rooms.lock().unwrap();

    // Remove username from map so new clients can take it.
    rooms.get_mut(&channel).unwrap().user_set.remove(&username);

    // TODO: Check if the room is empty now and remove the `RoomState` from the map.
}

// Include utf-8 file at **compile** time.
async fn index() -> Html<&'static str> {
    Html(std::include_str!("../chat.html"))
}
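One possible way to resolve the TODO at the end of websocket(): after removing the username, drop the room from the map once its user set is empty. This sketch uses a trimmed-down RoomState holding only user_set (the tx field is omitted so it runs without tokio), and leave is a hypothetical helper meant to be called while holding the rooms lock. Note that in the handler above a room can also be created by a client whose username is then rejected, so that path may need the same cleanup.

```rust
use std::collections::{HashMap, HashSet};

/// Trimmed-down stand-in for the example's RoomState (no tx field).
#[derive(Default)]
struct RoomState {
    user_set: HashSet<String>,
}

/// Removes `username` from `channel` and deletes the room once nobody is left.
/// Returns true if the room was removed.
fn leave(rooms: &mut HashMap<String, RoomState>, channel: &str, username: &str) -> bool {
    let now_empty = match rooms.get_mut(channel) {
        Some(room) => {
            room.user_set.remove(username);
            room.user_set.is_empty()
        }
        // The room is already gone; nothing to do.
        None => return false,
    };
    if now_empty {
        rooms.remove(channel);
    }
    now_empty
}
```

Removing the room also drops its Sender, so a later client joining the same name gets a fresh channel via or_insert_with(RoomState::new).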

chat.html

<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <title>WebSocket Chat</title>
</head>

<body>
    <h1>WebSocket Chat Example</h1>

    <input id="username" style="display:block; width:100px; box-sizing: border-box" type="text" placeholder="username">
    <input id="channel" style="display:block; width:100px; box-sizing: border-box" type="text" placeholder="channel">
    <button id="join-chat" type="button">Join Chat</button>
    <textarea id="chat" style="display:block; width:600px; height:400px; box-sizing: border-box" cols="30"
        rows="10"></textarea>
    <input id="input" style="display:block; width:600px; box-sizing: border-box" type="text" placeholder="chat">

    <script>
        const username = document.querySelector("#username");
        const channel = document.querySelector('#channel');
        const join_btn = document.querySelector("#join-chat");
        const textarea = document.querySelector("#chat");
        const input = document.querySelector("#input");

        join_btn.addEventListener("click", function (e) {
            this.disabled = true;

            const websocket = new WebSocket("ws://localhost:3000/websocket");

            websocket.onopen = function () {
                console.log("connection opened");
                websocket.send(JSON.stringify({ username: username.value, channel: channel.value }));
            }

            const btn = this;

            websocket.onclose = function () {
                console.log("connection closed");
                btn.disabled = false;
            }

            websocket.onmessage = function (e) {
                console.log("received message: " + e.data);
                textarea.value += e.data + "\r\n";
            }

            input.onkeydown = function (e) {
                if (e.key == "Enter") {
                    websocket.send(input.value);
                    input.value = "";
                }
            }
        });
    </script>
</body>

</html>

SO tempted to click that solution button, but I'm definitely giving it a try on my own, haha. I'm not 100% sure yet whether I'd need to spawn additional threads to do this, but I guess I'll find out soon enough. Thank you!!!