I'm using the code below with an unbounded `mpsc` channel, and it works very well.
Now I need to use bounded `mpsc` channels, and it no longer works: `.send()` must now be `.await`ed, because a bounded channel can be full.
Since I would like to remove empty `Vec`s and empty `HashMap` keys, I'm using `retain`, which unfortunately doesn't support an async callback:
pub async fn broadcast(&self, message: Message) {
let clients = self.clients.clone();
let message = Arc::new(message);
tokio::spawn(async move {
let mut clients = clients.lock().unwrap();
let players = clients.get_mut(&message.team_id).unwrap();
// I would like to remove dropped connections and remove empty Vec too below
// It works with unbounded channels because I don't need to ".await" on it, now with bounded channel I need to ".await" because it can be full
players.retain(|_, emitters| {
emitters.retain(|emitter| {
// Better logic here...
emitter.sender.send(message.clone()).await.is_ok()
});
!emitters.is_empty()
});
});
}
REPL: https://www.rustexplorer.com/b/hh5m79
Code:
/*
[dependencies]
axum = { version = "0.6.20" }
futures = { version = "0.3.28", default-features = false }
tokio = { version = "1.32.0", default-features = false, features = [
"macros",
"process",
"rt-multi-thread",
] }
tokio-stream = { version = "0.1.14", default-features = false, features = [
"sync",
] }
*/
use axum::{
extract::State,
response::{
sse::{Event, KeepAlive, Sse},
Html,
},
routing::get,
Router,
};
use futures::stream::Stream;
use std::{
collections::HashMap,
convert::Infallible,
sync::{Arc, Mutex},
};
use tokio::sync::mpsc;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
/// Identifier of a team; the first-level key of the broadcaster's client map.
type TeamId = String;
/// Identifier of a player within a team; the second-level key of the client map.
type PlayerId = String;
/// Registry of client connections, keyed first by team and then by player.
///
/// The map sits behind `Arc<Mutex<..>>` so a spawned task can take its own
/// handle to it (via `Arc` clone) while other handlers keep using it.
#[derive(Default)]
pub struct Broadcaster {
    /// team id -> player id -> every connection registered for that player.
    clients: Arc<Mutex<HashMap<TeamId, HashMap<PlayerId, Vec<Connection>>>>>,
}
/// One registered client connection: who it belongs to and the sending half of
/// its bounded channel.
#[derive(Debug)]
pub struct Connection {
    /// Session the connection was registered with.
    session_id: String,
    /// Player the connection belongs to (also its key in the player map).
    player_id: String,
    /// Sending half of the connection's channel; closed once the receiver is dropped.
    sender: mpsc::Sender<Arc<Message>>,
}
/// A chat message to be fanned out to every connection of one team.
#[derive(Debug)]
pub struct Message {
    /// Team whose connections receive the message.
    pub team_id: TeamId,
    /// Session the message is attributed to.
    pub session_id: String,
    /// Text payload, sent to clients as the SSE event data.
    pub message: String,
}
/// Shared axum state handed to every route handler.
struct AppState {
    /// The single broadcaster shared by all handlers.
    broadcaster: Arc<Broadcaster>,
}
/// Serves `/send_message` and `/sse` on `0.0.0.0:3000`, with both routes sharing
/// one `Broadcaster` through `AppState`.
///
/// Panics if the address cannot be parsed or the server fails.
#[tokio::main]
async fn main() {
    let shared = Arc::new(AppState {
        broadcaster: Arc::new(Broadcaster::default()),
    });

    let router = Router::new()
        .route("/send_message", get(send_message))
        .route("/sse", get(sse_handler))
        .with_state(shared);

    let addr: std::net::SocketAddr = "0.0.0.0:3000".parse().unwrap();
    let server = axum::Server::bind(&addr).serve(router.into_make_service());
    server.await.unwrap();
}
/// `GET /send_message`: broadcasts a placeholder message to team `"fake_one"`.
///
/// Always answers `"Message sent"`. `broadcast` hands delivery to a spawned task,
/// so the reply does not confirm that any client actually received it.
async fn send_message(State(app_state): State<Arc<AppState>>) -> Html<&'static str> {
    let placeholder = String::from("fake_one");
    let outgoing = Message {
        team_id: placeholder.clone(),
        session_id: placeholder.clone(),
        message: placeholder,
    };
    let broadcaster = &app_state.broadcaster;
    broadcaster.broadcast(outgoing).await;
    Html("Message sent")
}
/// `GET /sse`: registers a new client under hard-coded `"fake_one"` ids and streams
/// each message it receives as an SSE event whose data is the message text.
///
/// The connection is kept alive with `KeepAlive::default()`. The stream never
/// yields an error (`Infallible`).
async fn sse_handler(
    State(app_state): State<Arc<AppState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let receiver = app_state
        .broadcaster
        .add_client("fake_one", "fake_one", "fake_one")
        .await;

    let events = ReceiverStream::new(receiver).map(|incoming: Arc<Message>| {
        let event = Event::default().data(incoming.message.to_string());
        Ok(event)
    });

    Sse::new(events).keep_alive(KeepAlive::default())
}
impl Broadcaster {
pub async fn add_client(
&self,
session_id: &str,
team_id: &str,
player_id: &str,
) -> mpsc::Receiver<Arc<Message>> {
let (tx, rx) = mpsc::channel::<Arc<Message>>(10);
let mut clients = self.clients.lock().unwrap();
if !clients.contains_key(team_id) {
clients.insert(team_id.to_string(), HashMap::new());
}
let players = clients.get_mut(team_id).unwrap();
if !players.contains_key(player_id) {
players.insert(player_id.to_string().into(), Vec::new());
}
let connections = players.get_mut(player_id).unwrap();
let connection = Connection {
session_id: session_id.to_string(),
player_id: player_id.to_string(),
sender: tx,
};
connections.push(connection);
rx
}
pub async fn broadcast(&self, message: Message) {
let clients = self.clients.clone();
let message = Arc::new(message);
tokio::spawn(async move {
let mut clients = clients.lock().unwrap();
let players = clients.get_mut(&message.team_id).unwrap();
// I would like to remove dropped connections and remove empty Vec too
// It works with unbounded channels because I don't need to ".await" on it, now with bounded channel I need to await because it can be full
players.retain(|_, emitters| {
emitters.retain(|emitter| {
// Better logic here...
emitter.sender.send(message.clone()).await.is_ok()
});
!emitters.is_empty()
});
});
}
}