How to .await inside .retain callback?

I'm using the code below with an unbounded mpsc channel and it works amazingly well.

Now I need to use bounded mpsc channels, and it no longer works because I have to .await the .send() call, since the channel can be full.

Since I would like to remove empty Vecs and empty HashMap keys, I'm using retain, which unfortunately doesn't support an async callback:

pub async fn broadcast(&self, message: Message) {
    let clients = self.clients.clone();

    let message = Arc::new(message);

    tokio::spawn(async move {
        let mut clients = clients.lock().unwrap();

        let players = clients.get_mut(&message.team_id).unwrap();

        // I would like to remove dropped connections and remove empty Vec too below

        // It works with unbounded channels because I don't need to ".await" on it, now with bounded channel I need to ".await" because it can be full

        players.retain(|_, emitters| {
            emitters.retain(|emitter| {
                // Better logic here...
                emitter.sender.send(message.clone()).await.is_ok()
            });

            !emitters.is_empty()
        });
    });
}

REPL: https://www.rustexplorer.com/b/hh5m79

Code:

/*
[dependencies]
axum = { version = "0.6.20" }
futures = { version = "0.3.28", default-features = false }
tokio = { version = "1.32.0", default-features = false, features = [
    "macros",
    "process",
    "rt-multi-thread",
] }
tokio-stream = { version = "0.1.14", default-features = false, features = [
    "sync",
] }
*/

use axum::{
    extract::State,
    response::{
        sse::{Event, KeepAlive, Sse},
        Html,
    },
    routing::get,
    Router,
};
use futures::stream::Stream;
use std::{
    collections::HashMap,
    convert::Infallible,
    sync::{Arc, Mutex},
};
use tokio::sync::mpsc;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};

type TeamId = String;
type PlayerId = String;

#[derive(Default)]
pub struct Broadcaster {
    clients: Arc<Mutex<HashMap<TeamId, HashMap<PlayerId, Vec<Connection>>>>>,
}

pub struct Connection {
    session_id: String,
    player_id: String,
    sender: mpsc::Sender<Arc<Message>>,
}

pub struct Message {
    pub team_id: TeamId,
    pub session_id: String,
    pub message: String,
}

struct AppState {
    broadcaster: Arc<Broadcaster>,
}

#[tokio::main]
async fn main() {
    let broadcaster = Arc::new(Broadcaster::default());

    let app_state = Arc::new(AppState { broadcaster });

    let app = Router::new()
        .route("/send_message", get(send_message))
        .route("/sse", get(sse_handler))
        .with_state(app_state);

    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

async fn send_message(State(app_state): State<Arc<AppState>>) -> Html<&'static str> {
    let new_fake_message = Message {
        team_id: "fake_one".to_string(),
        session_id: "fake_one".to_string(),
        message: "fake_one".to_string(),
    };

    app_state.broadcaster.broadcast(new_fake_message).await;

    Html("Message sent")
}

async fn sse_handler(
    State(app_state): State<Arc<AppState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let rx = app_state
        .broadcaster
        .add_client("fake_one", "fake_one", "fake_one")
        .await;

    let mystream = ReceiverStream::<Arc<Message>>::new(rx)
        .map(|res| Ok(Event::default().data(res.message.to_string())));

    Sse::new(mystream).keep_alive(KeepAlive::default())
}

impl Broadcaster {
    pub async fn add_client(
        &self,
        session_id: &str,
        team_id: &str,
        player_id: &str,
    ) -> mpsc::Receiver<Arc<Message>> {
        let (tx, rx) = mpsc::channel::<Arc<Message>>(10);

        let mut clients = self.clients.lock().unwrap();

        if !clients.contains_key(team_id) {
            clients.insert(team_id.to_string(), HashMap::new());
        }

        let players = clients.get_mut(team_id).unwrap();

        if !players.contains_key(player_id) {
            players.insert(player_id.to_string(), Vec::new());
        }

        let connections = players.get_mut(player_id).unwrap();

        let connection = Connection {
            session_id: session_id.to_string(),
            player_id: player_id.to_string(),
            sender: tx,
        };

        connections.push(connection);

        rx
    }

    pub async fn broadcast(&self, message: Message) {
        let clients = self.clients.clone();

        let message = Arc::new(message);

        tokio::spawn(async move {
            let mut clients = clients.lock().unwrap();

            let players = clients.get_mut(&message.team_id).unwrap();

            // I would like to remove dropped connections and remove empty Vec too

            // It works with unbounded channels because I don't need to ".await" on it, now with bounded channel I need to await because it can be full

            players.retain(|_, emitters| {
                emitters.retain(|emitter| {
                    // Better logic here...
                    emitter.sender.send(message.clone()).await.is_ok()
                });

                !emitters.is_empty()
            });
        });
    }
}

tokio::sync::mpsc::Sender has a blocking_send method you don't need to .await. Never mind, you can't call it from within the tokio runtime.

You could use it in conjunction with spawn_blocking. That does not trigger the panic when calling blocking_send from within the runtime:

        tokio::task::spawn_blocking(move || {
            let mut clients = clients.lock().unwrap();

            let players = clients.get_mut(&message.team_id).unwrap();

            // I would like to remove dropped connections and remove empty Vec too

            // It works with unbounded channels because I don't need to ".await" on it, now with bounded channel I need to await because it can be full

            players.retain(|_, emitters| {
                emitters.retain(|emitter| {
                    // Handle better logic here...
                    emitter.sender.blocking_send(message.clone()).is_ok()
                });

                !emitters.is_empty()
            });
        })
        .await
        .unwrap();

But this makes everything synchronous instead of asynchronous.

That makes a few send calls block, and only if their queues are full, all on a dedicated thread that can process as many players and emitters as it likes without bothering the rest of the runtime. I'd be much more concerned about locking clients for the whole operation than about the blocking sends.
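
For instance (a rough sketch, not something from the thread): you could take the lock only long enough to clone the senders, then do the blocking sends without holding it. Cleaning up the empty Vecs and maps would then be a separate pass.

tokio::task::spawn_blocking(move || {
    // Narrow the critical section to cloning the senders.
    let senders: Vec<mpsc::Sender<Arc<Message>>> = {
        let clients = clients.lock().unwrap();
        clients
            .get(&message.team_id)
            .map(|players| {
                players
                    .values()
                    .flatten()
                    .map(|connection| connection.sender.clone())
                    .collect()
            })
            .unwrap_or_default()
    }; // guard dropped here

    for sender in senders {
        // A failed send just means the receiver was dropped; removing the
        // dead entries from the maps can happen in a separate pass.
        let _ = sender.blocking_send(message.clone());
    }
});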

Yes, I also had doubts about clients.lock().

Now my head is exploding over these two issues.

How can I better understand what is happening?

Why do you say "only if their queues are full"? Isn't the blocking send called each time?

It is, but if the receiver's queue is not full, blocking_send will return immediately. Otherwise it blocks until it can place the message into the receiver's queue.

To be honest, the answer is "don't".

You could replace the retain call with a for loop. You could also write the messages to send to a vector, then send them after the call to retain.

Thanks. Can I ask for an example, please?

And how do I detect whether the send failed (i.e. !is_ok())?

Using this:

pub async fn broadcast(&self, message: Message) {
    let clients = self.clients.clone();

    let message = Arc::new(message);

    tokio::spawn(async move {
        let mut clients = clients.lock().unwrap();

        let players = clients.get_mut(&message.team_id).unwrap();

        for (player_id, connections) in players {
            for connection in connections {
                connection.sender.send(message.clone()).await;
            }
        }
    });
}

the error is:

error: future cannot be sent between threads safely
   --> src\main.rs:121:22
    |
121 |           tokio::spawn(async move {
    |  ______________________^
123 | |             let mut clients = clients.lock().unwrap();
124 | |
...   |
131 | |             }
132 | |         });
    | |_________^ future created by async block is not `Send`
    |
    = help: within `[async block@src\main.rs:121:22: 132:10]`, the trait `Send` is not implemented for `std::sync::MutexGuard<'_, HashMap<String, HashMap<String, Vec<Connection>>>>`
note: future is not `Send` as this value is used across an await
   --> src\main.rs:129:61
    |
123 |             let mut clients = clients.lock().unwrap();
    |                 ----------- has type `std::sync::MutexGuard<'_, HashMap<String, HashMap<String, Vec<Connection>>>>` which is not `Send`
...
129 |                     connection.sender.send(message.clone()).await;
    |                                                             ^^^^^ await occurs here, with `mut clients` maybe used later
...
132 |         });
    |         - `mut clients` is later dropped here
note: required by a bound in `tokio::spawn`
   --> C:\Users\User\.cargo\registry\src\index.crates.io-6f17d22bba15001f\tokio-1.33.0\src\task\spawn.rs:166:21
    |
164 |     pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
    |            ----- required by a bound in this function
165 |     where
166 |         F: Future + Send + 'static,
    |                     ^^^^ required by this bound in `spawn`

If you keep the lock across .await points, you need to replace the Mutex from the standard library with the one from tokio (tokio::sync::Mutex).
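
A minimal sketch of that change, reusing the types from the code above; only the Mutex import and the lock() call change:

use tokio::sync::Mutex; // in Broadcaster, instead of std::sync::Mutex

pub async fn broadcast(&self, message: Message) {
    let clients = self.clients.clone();
    let message = Arc::new(message);

    tokio::spawn(async move {
        // Locking is itself an .await, and tokio's guard may be held across
        // the .await on send() because it is Send.
        let mut clients = clients.lock().await;

        if let Some(players) = clients.get_mut(&message.team_id) {
            for connections in players.values() {
                for connection in connections {
                    let _ = connection.sender.send(message.clone()).await;
                }
            }
        }
    });
}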

Yeah. I just read Shared state | Tokio - An asynchronous Rust runtime (again after months). Thanks.

Do you know what this means?

You could replace the retain call with a for loop. You could also write the messages to send to a vector, then send them after the call to retain .

My interpretation would be something like this, but maybe (probably) alice had something better in mind:

        tokio::spawn(async move {
            let mut clients = clients.lock().await;

            let players = clients.remove(&message.team_id).unwrap();
            
            let mut retained_players = HashMap::new();

            for (key, emitters) in players {
                let mut retained_emitters = Vec::new();

                for emitter in emitters {
                    if emitter.sender.send(message.clone()).await.is_ok() {
                        retained_emitters.push(emitter);
                    }
                }

                if !retained_emitters.is_empty() {
                    retained_players.insert(key, retained_emitters);
                }
            }
            
            clients.insert(message.team_id.clone(), retained_players);
        });

Thanks! The bad side of this is the new allocation ON EACH broadcast call.

Right?

The compiler is pretty good at optimising stuff. The only way to know for sure whether this will be a bottleneck is to run your application under production load and measure.

Yeah, I'll measure for sure. But I already know that I'll definitely be broadcasting lots of messages to many different teams.

So broadcast() will be called a lot.

I'm thinking of a different way to clean up dropped connections, using a technique I saw here: https://github.com/chaudharypraveen98/actix-sse-example/blob/master/src/broadcast.rs#L30.

In practice, every so often it launches a task that sends a ping and checks for dropped connections (as I'm doing now).
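
Something along these lines, I guess (a rough sketch, not the code from that repo: instead of actually sending a ping it just checks Sender::is_closed, the spawn_cleanup_task helper is a made-up name, and it assumes tokio's time feature is enabled):

fn spawn_cleanup_task(broadcaster: Arc<Broadcaster>) {
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));

        loop {
            interval.tick().await;

            // No .await while the guard is alive, so the std Mutex is fine here.
            let mut clients = broadcaster.clients.lock().unwrap();

            for players in clients.values_mut() {
                for emitters in players.values_mut() {
                    // is_closed() is true once the SSE handler (the receiver) is dropped.
                    emitters.retain(|emitter| !emitter.sender.is_closed());
                }
                players.retain(|_, emitters| !emitters.is_empty());
            }
            clients.retain(|_, players| !players.is_empty());
        }
    });
}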

And I'm also thinking of using an RwLock so that Broadcaster only takes the write lock when it actually needs to (rough sketch below):

  1. when I insert a new connection (add_client())

  2. when I remove dropped connections.
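
Something like this is what I have in mind (a rough, hypothetical sketch; broadcast would only need the read lock, since sending on an mpsc::Sender only takes &self):

use tokio::sync::RwLock;

pub struct Broadcaster {
    clients: Arc<RwLock<HashMap<TeamId, HashMap<PlayerId, Vec<Connection>>>>>,
}

impl Broadcaster {
    pub async fn broadcast(&self, message: Message) {
        let message = Arc::new(message);

        // Read lock: many broadcasts can run in parallel, and writers are only
        // blocked for as long as it takes to clone the senders.
        let senders: Vec<mpsc::Sender<Arc<Message>>> = {
            let clients = self.clients.read().await;
            clients
                .get(&message.team_id)
                .map(|players| {
                    players
                        .values()
                        .flatten()
                        .map(|connection| connection.sender.clone())
                        .collect()
                })
                .unwrap_or_default()
        };

        for sender in senders {
            let _ = sender.send(message.clone()).await;
        }
    }

    // add_client() and the periodic cleanup task would take self.clients.write().await
}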

What do you think?

You can avoid allocating by keeping the spare collections around in self.

If you don't care about order, then the inner retain is easy to get rid of.

let mut i = 0;
while let Some(emitter) = emitters.get(i) {
    if emitter.sender.send(message.clone()).await.is_ok() {
        i += 1;
        continue;
    }
    emitters.swap_remove(i);
}

You could also do an ordered version of this, but it would be more complicated and slower.
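
For what it's worth, an ordered version could look something like this (just a sketch): compact the surviving emitters to the front and truncate, which preserves their relative order.

// Ordered variant: keep survivors at the front, in order, then truncate.
let mut kept = 0;
for i in 0..emitters.len() {
    if emitters[i].sender.send(message.clone()).await.is_ok() {
        emitters.swap(kept, i);
        kept += 1;
    }
}
emitters.truncate(kept);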

HashMap is more difficult, or at least less efficient.

// you don't have to build this every time if you keep it updated
let player_ids: Vec<PlayerId> = players.keys().cloned().collect();
for player_id in &player_ids {
    let Some(emitters) = players.get_mut(player_id) else {
        unreachable!()
    };

    let mut i = 0;
    while let Some(emitter) = emitters.get(i) {
        if emitter.sender.send(message.clone()).await.is_ok() {
            i += 1;
            continue;
        }
        emitters.swap_remove(i);
    }

    if emitters.is_empty() {
        players.remove(player_id);
    }
}

The other thing you can do is just make a list of keys/indices that need deletion, and use retain after sending all the messages.
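
One sketch of that (hedged): remember which players had a failed send, then do the structural cleanup afterwards with plain retain / remove calls, using Sender::is_closed rather than exact indices. Like the for-loop version above, this still holds the (tokio) lock across the .awaits.

// First pass: send to everyone, remember the players with at least one failed send.
let mut players_to_clean: Vec<PlayerId> = Vec::new();

for (player_id, emitters) in players.iter() {
    for emitter in emitters {
        if emitter.sender.send(message.clone()).await.is_err() {
            players_to_clean.push(player_id.clone());
        }
    }
}
// Failed sends for the same player are pushed consecutively, so dedup works.
players_to_clean.dedup();

// Second pass: plain, synchronous cleanup for just those players.
for player_id in players_to_clean {
    let now_empty = match players.get_mut(&player_id) {
        Some(emitters) => {
            emitters.retain(|emitter| !emitter.sender.is_closed());
            emitters.is_empty()
        }
        None => false,
    };

    if now_empty {
        players.remove(&player_id);
    }
}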

Honestly, these retain functions could somewhat easily be made async since they already need to deal with panic safety. Copying the code from std and making it async would be sensible if you need the performance.

I agree that this should be done less frequently than broadcast. Maybe you could keep track of how many channels are closed and only clean them up when there's a lot of them.

Can I ask you to show how please? :pray:

Save the number of disconnections.

type TeamPlayerMap = HashMap<TeamId, HashMap<PlayerId, Vec<Connection>>>;

pub struct Broadcaster {
    clients: Arc<Mutex<TeamPlayerMap>>,
    disconnected_channels: u32,
}

Put the sender in an Option. Or you could put the whole Connection in the Option if you want.

pub struct Connection {
    session_id: String,
    player_id: String,
    sender: Option<mpsc::Sender<Arc<Message>>>,
}

Then whenever you fail to send, add one to disconnected_channels and replace sender with None.

And if the count ever reaches some limit (maybe a fraction of the total channels), then you'd call the cleanup method.
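
A hedged sketch of how broadcast might put this together, with a couple of adjustments so it hangs together: clients behind a tokio::sync::Mutex (the guard is held across the .await on send()), disconnected_channels as an AtomicU32 so it can be incremented through &self, and CLEANUP_THRESHOLD / cleanup() as hypothetical placeholders.

use std::sync::atomic::{AtomicU32, Ordering};

const CLEANUP_THRESHOLD: u32 = 100; // arbitrary placeholder limit

impl Broadcaster {
    pub async fn broadcast(&self, message: Message) {
        let message = Arc::new(message);
        let mut clients = self.clients.lock().await;

        if let Some(players) = clients.get_mut(&message.team_id) {
            for connections in players.values_mut() {
                for connection in connections.iter_mut() {
                    let alive = match &connection.sender {
                        Some(sender) => sender.send(message.clone()).await.is_ok(),
                        // Already marked dead on an earlier broadcast.
                        None => continue,
                    };

                    if !alive {
                        connection.sender = None;
                        self.disconnected_channels.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }
        }
        drop(clients);

        if self.disconnected_channels.load(Ordering::Relaxed) >= CLEANUP_THRESHOLD {
            // cleanup() (hypothetical) would take the lock again, drop the
            // `None` connections, remove empty Vecs / maps and reset the counter.
            self.cleanup().await;
        }
    }
}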