You could tag every message with a "room ID" and keep a single shared broadcast channel, but that has one major drawback: every message sent in any room would wake every task that listens for new messages, just so each task could check whether the message's room ID matched its own.
Another strategy starts from the observation that the example code already defines a single chat room. If we gather all of the state associated with that room into a "room" data structure, adding multiple chat rooms becomes fairly straightforward.
I'd suggest trying that yourself first as an exercise, but if you get stuck, my version is included in the collapsed "Solution" section below.
As a hint: you can create a new `broadcast::channel(100)` for each newly created room instead of creating a single one in `main()`.
Solution
main.rs
//! Example chat application.
//!
//! Run with
//!
//! ```not_rust
//! cd examples && cargo run -p example-chat
//! ```
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::{Html, IntoResponse},
routing::get,
Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use serde::Deserialize;
use std::{
collections::{HashMap, HashSet},
net::SocketAddr,
sync::{Arc, Mutex},
};
use tokio::sync::broadcast;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Application state shared with the request handlers through axum's `State`
/// extractor.
struct AppState {
    /// Chat rooms indexed by channel name. This is a blocking `std` mutex, so
    /// its guard must be released before any `.await`.
    rooms: Mutex<HashMap<String, RoomState>>,
}
/// State of a single chat room: who is in it and how to reach them.
struct RoomState {
    /// Usernames currently in this room. Names only have to be unique within
    /// a room, not across the whole server.
    user_set: HashSet<String>,
    /// Sending half of this room's broadcast channel. Connections subscribe to
    /// it to receive the room's messages and clone it to post into the room.
    tx: broadcast::Sender<String>,
}

impl RoomState {
    /// Number of messages the room's broadcast channel retains for receivers
    /// that have not read them yet.
    const CHANNEL_CAPACITY: usize = 100;

    /// Creates an empty room with its own, independent broadcast channel.
    fn new() -> Self {
        // Only the sender is kept; receivers are created later through
        // `tx.subscribe()`, so the initial receiver is dropped right away.
        let (tx, _) = broadcast::channel(Self::CHANNEL_CAPACITY);
        Self {
            user_set: HashSet::new(),
            tx,
        }
    }
}
/// Configures tracing, builds the router and serves the chat app on
/// `127.0.0.1:3000`.
///
/// The log filter is read from `RUST_LOG` and defaults to `example_chat=trace`.
#[tokio::main]
async fn main() {
    let filter_directives =
        std::env::var("RUST_LOG").unwrap_or_else(|_| "example_chat=trace".into());
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new(filter_directives))
        .with(tracing_subscriber::fmt::layer())
        .init();

    // Rooms are created on demand as clients join, so start with none.
    let shared_state = Arc::new(AppState {
        rooms: Mutex::new(HashMap::new()),
    });

    let router = Router::with_state(shared_state)
        .route("/", get(index))
        .route("/websocket", get(websocket_handler));

    let listen_addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
    tracing::debug!("listening on {}", listen_addr);

    let server = axum::Server::bind(&listen_addr).serve(router.into_make_service());
    server.await.unwrap();
}
async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket(socket, state))
}
/// Handles one chat connection, from the connect handshake until disconnect.
///
/// The first text frame from the client must be a JSON object of the form
/// `{"username": "...", "channel": "..."}`; non-text frames before it are
/// skipped. If that frame cannot be parsed, or the username is empty or already
/// taken in the requested channel, an error message is sent back and the
/// connection is dropped. If the client goes away before sending it, the
/// function simply returns.
///
/// Once joined, a "joined" message is broadcast to the room and messages are
/// relayed in both directions until either side stops. Then a "left" message
/// is broadcast, the username is released, and the room is removed once its
/// last user has left.
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
    /// Connect request expected as the client's first text message.
    #[derive(Deserialize)]
    struct Connect {
        username: String,
        channel: String,
    }

    // By splitting we can send and receive at the same time.
    let (mut sender, mut receiver) = stream.split();

    // Wait for the first text frame. A closed or failed stream means the client
    // left before joining any room, so there is nothing to clean up.
    let text = loop {
        match receiver.next().await {
            Some(Ok(Message::Text(text))) => break text,
            Some(Ok(_)) => continue,
            Some(Err(_)) | None => return,
        }
    };

    let Connect { username, channel } = match serde_json::from_str::<Connect>(&text) {
        Ok(connect) => connect,
        Err(error) => {
            tracing::error!(%error);
            let _ = sender
                .send(Message::Text(String::from("Failed to parse connect message")))
                .await;
            return;
        }
    };

    // Registration runs in a synchronous helper, so the mutex guard is
    // released before the next `.await`.
    let tx = match join_room(&state, &channel, &username) {
        Some(tx) => tx,
        None => {
            // Only send our client that username is taken.
            let _ = sender
                .send(Message::Text(String::from("Username already taken.")))
                .await;
            return;
        }
    };

    // Subscribe before sending the joined message so this client sees it too.
    let mut rx = tx.subscribe();

    let msg = format!("{} joined.", username);
    tracing::debug!("{}", msg);
    let _ = tx.send(msg);

    // Forward room broadcasts to this client. `recv` fails once the channel is
    // closed or, with `Lagged`, if this receiver falls behind the channel's
    // capacity; either way the loop ends.
    let mut send_task = tokio::spawn(async move {
        while let Ok(msg) = rx.recv().await {
            // In any websocket error, break loop.
            if sender.send(Message::Text(msg)).await.is_err() {
                break;
            }
        }
    });

    // Forward this client's text messages to the room, prefixed with the
    // username. The clones keep the outer `tx` usable for the "left" message.
    let mut recv_task = {
        let tx = tx.clone();
        let name = username.clone();
        tokio::spawn(async move {
            while let Some(Ok(Message::Text(text))) = receiver.next().await {
                let _ = tx.send(format!("{}: {}", name, text));
            }
        })
    };

    // If any one of the tasks exit, abort the other.
    tokio::select! {
        _ = (&mut send_task) => recv_task.abort(),
        _ = (&mut recv_task) => send_task.abort(),
    };

    let msg = format!("{} left.", username);
    tracing::debug!("{}", msg);
    let _ = tx.send(msg);

    // Free the username for new clients and drop the room if it is now empty.
    leave_room(&state, &channel, &username);
}

/// Registers `username` in the room for `channel`, creating the room if needed.
///
/// Returns a clone of the room's broadcast sender. Returns `None` if the name is
/// empty or already in use in that room. Empty names are rejected before the
/// map is touched, so a rejected request never creates a room. A taken name
/// implies the room already exists.
fn join_room(state: &AppState, channel: &str, username: &str) -> Option<broadcast::Sender<String>> {
    if username.is_empty() {
        return None;
    }
    let mut rooms = state.rooms.lock().unwrap();
    let room = rooms
        .entry(channel.to_owned())
        .or_insert_with(RoomState::new);
    // `insert` returns false when the name is already present in this room.
    if room.user_set.insert(username.to_owned()) {
        Some(room.tx.clone())
    } else {
        None
    }
}

/// Removes `username` from the room for `channel`, and removes the room itself
/// once no users remain in it.
fn leave_room(state: &AppState, channel: &str, username: &str) {
    let mut rooms = state.rooms.lock().unwrap();
    let now_empty = match rooms.get_mut(channel) {
        Some(room) => {
            room.user_set.remove(username);
            room.user_set.is_empty()
        }
        None => false,
    };
    // Joins and leaves both happen under this lock, so an empty `user_set`
    // means no connection is using the room any more.
    if now_empty {
        rooms.remove(channel);
    }
}
/// Serves the chat page at `/`.
///
/// `chat.html` is read at **compile** time and embedded in the binary as UTF-8
/// text, so the file is not needed at runtime.
async fn index() -> Html<&'static str> {
    const CHAT_PAGE: &str = include_str!("../chat.html");
    Html(CHAT_PAGE)
}
chat.html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>WebSocket Chat</title>
</head>
<body>
    <h1>WebSocket Chat Example</h1>
    <input id="username" style="display:block; width:100px; box-sizing: border-box" type="text" placeholder="username">
    <input id="channel" style="display:block; width:100px; box-sizing: border-box" type="text" placeholder="channel">
    <button id="join-chat" type="button">Join Chat</button>
    <textarea id="chat" style="display:block; width:600px; height:400px; box-sizing: border-box" cols="30"
        rows="10"></textarea>
    <input id="input" style="display:block; width:600px; box-sizing: border-box" type="text" placeholder="chat">
    <script>
        const username = document.querySelector("#username");
        const channel = document.querySelector("#channel");
        const join_btn = document.querySelector("#join-chat");
        const textarea = document.querySelector("#chat");
        const input = document.querySelector("#input");

        // Build the WebSocket URL from the page's own origin, so the client
        // works on any host/port and uses wss: when the page is served over https.
        function websocketUrl() {
            const scheme = window.location.protocol === "https:" ? "wss:" : "ws:";
            return scheme + "//" + window.location.host + "/websocket";
        }

        join_btn.addEventListener("click", function () {
            const btn = this;
            btn.disabled = true;

            const websocket = new WebSocket(websocketUrl());

            websocket.onopen = function () {
                console.log("connection opened");
                // The first message is the connect request the server expects.
                websocket.send(JSON.stringify({ username: username.value, channel: channel.value }));
            };

            websocket.onclose = function () {
                console.log("connection closed");
                btn.disabled = false;
            };

            websocket.onmessage = function (e) {
                console.log("received message: " + e.data);
                textarea.value += e.data + "\r\n";
                // Keep the newest message in view.
                textarea.scrollTop = textarea.scrollHeight;
            };

            input.onkeydown = function (e) {
                // send() throws while the socket is still connecting and silently
                // drops data once it is closing/closed, so only send when open
                // (the typed text is kept otherwise).
                if (e.key === "Enter" && websocket.readyState === WebSocket.OPEN) {
                    websocket.send(input.value);
                    input.value = "";
                }
            };
        });
    </script>
</body>
</html>