How to detect a dropped SSE (server-sent events) client using Axum?

In the code below I'm handling SSE connections within the function sse_handler(). Is there a way to know when a connection is closed so I can remove the client from Broadcaster.inner.clients?

REPL: https://www.rustexplorer.com/b/xkqyt2

/*
[dependencies]
axum = { version = "0.6.20" }
futures = { version = "0.3.28", default-features = false }
tokio = { version = "1.32.0", default-features = false, features = [
    "macros",
    "process",
    "rt-multi-thread",
] }
tokio-stream = { version = "0.1.14", default-features = false, features = [
    "sync",
] }
*/

use axum::{
    extract::State,
    response::{
        sse::{Event, KeepAlive, Sse},
        Html,
    },
    routing::get,
    Router,
};
use futures::stream::Stream;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use std::{
    convert::Infallible,
    sync::{Arc, Mutex},
};
use tokio::sync::mpsc;

struct BroadcasterInner {
    clients: Vec<mpsc::Sender<String>>,
}

pub struct Broadcaster {
    inner: Mutex<BroadcasterInner>,
}

struct AppState {
    broadcaster: Arc<Broadcaster>,
}

#[tokio::main]
async fn main() {
    let broadcaster = Broadcaster::new();

    let app_state = Arc::new(AppState { broadcaster });

    let app = Router::new()
        .route("/send_message", get(send_message))
        .route("/sse", get(sse_handler))
        .with_state(app_state);

    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

async fn send_message(State(app_state): State<Arc<AppState>>) -> Html<&'static str> {
    app_state.broadcaster.broadcast("message").await;

    Html("Message sent")
}

async fn sse_handler(
    State(app_state): State<Arc<AppState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let rx = app_state.broadcaster.add_client().await;

    let mystream = ReceiverStream::<String>::new(rx).map(|res| Ok(Event::default().data(res)));

    Sse::new(mystream).keep_alive(KeepAlive::default())
}

impl Broadcaster {
    pub fn new() -> Arc<Self> {
        Arc::new(Broadcaster {
            inner: Mutex::new(BroadcasterInner {
                clients: Vec::new(),
            }),
        })
    }

    pub async fn add_client(&self) -> mpsc::Receiver<String> {
        let (tx, rx) = mpsc::channel::<String>(10);

        tx.send("welcome".to_string()).await.unwrap();

        self.inner.lock().unwrap().clients.push(tx);

        rx
    }

    pub async fn broadcast(&self, event: &str) {
        let clients = self.inner.lock().unwrap().clients.clone();

        let send_futures = clients.iter().map(|client| client.send(event.to_string()));

        let _ = futures::future::join_all(send_futures).await;
    }
}

One way would be to check the result of client.send(...) in the broadcast function and remove it from the Broadcaster.inner.clients vec when there is an error.

e.g.

    pub async fn broadcast(&self, event: &str) {
        let clients = self.inner.lock().unwrap().clients.clone();
        println!("number of clients before broadcast: {}", clients.len());
        let send_futures = clients.iter().map(|client| client.send(event.to_string()));

        // join_all yields the results in the same order as the futures were given
        let results = futures::future::join_all(send_futures).await;
        let mut results_iter = results.iter();
        let mut lock = self.inner.lock().unwrap();
        // only retain clients that did not return an error when trying to send;
        // clients pushed while we were awaiting have no result yet, so keep them
        // instead of panicking on a missing result
        lock.clients.retain(|_| results_iter.next().map_or(true, |res| res.is_ok()));
        println!("number of clients after broadcast: {}", lock.clients.len());
    }

(This relies on join_all returning the results in the same order as the futures were given, which its documentation guarantees. It also assumes no other task modifies the clients vec while the sends are in flight; otherwise the results and the clients can get out of step.)

I'm no expert on networking, but in my experience detecting network disconnections is not trivial (someone more knowledgeable might be able to expand on this), so I think the only reliable way is to check for errors when sending.


Thank you very much. This was more or less my idea that I'm trying to implement.

I have a problem, though: I would like to use tokio::spawn so as not to block the thread when I send an event, but I'm getting the following error. How do I solve it?

pub async fn broadcast(&self, event: &str) {
    tokio::spawn(async move {
        let clients = self.inner.lock().unwrap().clients.clone();
        println!("number of clients before broadcast: {}", clients.len());
        let send_futures = clients.iter().map(|client| client.send(event.to_string()));

        let results = futures::future::join_all(send_futures).await;
        let mut results_iter = results.iter();
        let mut lock = self.inner.lock().unwrap();
        // only retain clients that did not return an error when trying to send
        lock.clients.retain(|_| results_iter.next().unwrap().is_ok());
        println!("number of clients after broadcast: {}", lock.clients.len());
    });
}
error[E0521]: borrowed data escapes outside of method
   |
47 |       pub fn broadcast(&self, event: &str) {
   |                        -----
   |                        |
   |                        `self` is a reference that is only valid in the method body
   |                        let's call the lifetime of this reference `'1`
...
64 | /         tokio::spawn(async move {
65 | |             let clients = self.inner.lock().unwrap().clients.clone();
66 | |             println!("number of clients before broadcast: {}", clients.len());
67 | |             let send_futures = clients.iter().map(|client| client.send(event.to_string()));
79 | |         });
   | |          ^
   | |          |
   | |__________`self` escapes the method body here
   |            argument requires that `'1` must outlive `'static`

I guess what you actually mean is that you want to avoid awaiting the broadcast call inside the send_message handler, so that the send_message request can return before the messages are actually sent. The term "thread" isn't really accurate here, as awaiting doesn't block a thread; it yields control back to the runtime.

The problem with your code is that tokio::spawn requires a 'static future, since the spawned task may outlive the call that created it, but your async block borrows self (and event), and those references stop being valid once broadcast returns.

The easiest fix is probably to change the send_message function to give the spawned task its own Arc<AppState> and call broadcast from inside the async block passed to tokio::spawn.

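async fn send_message(State(app_state): State<Arc<AppState>>) -> Html<&'static str> {
    // `app_state` is already an owned Arc<AppState> (the State extractor clones
    // it for each request), so it can be moved into the spawned task as-is
    tokio::spawn(async move {
        // The task owns everything it uses, satisfying tokio::spawn's 'static bound
        app_state.broadcaster.broadcast("message").await;
    });

    // Returns immediately; the broadcast keeps running in the background
    Html("Message sent")
}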

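Another way would be to change the inner field from Mutex<BroadcasterInner> to Arc<Mutex<BroadcasterInner>>, clone that Arc at the start of the broadcast function, and use the clone instead of self.inner inside the tokio::spawn block. The event: &str parameter has the same problem as self, so it also needs to be turned into an owned String before spawning.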


It works. Thanks!

I have a question now: someone on SO answered:

You actually want a tokio::sync::broadcast channel. Using a broadcast channel will generally just work, as you don't have to do anything to handle a closed connection because there is only one sender.

When the connection is closed, the response value gets dropped. This will transitively drop rx, which closes the channel. The broadcast function will receive an error from the send call for senders whose receivers are all gone, so you can react to this error in broadcast to clean up orphaned senders.

What do you think about this? Isn't it the same as using mpsc instead of broadcast?

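That sounds right to me, and it isn't quite the same as using mpsc: with mpsc you keep one Sender per client and have to prune the dead ones yourself, whereas a broadcast channel has a single Sender and each client's Receiver simply goes away when its connection is dropped. So replacing your BroadcasterInner struct with a Sender from tokio::sync::broadcast, and using its subscribe method to get the rx for your /sse stream instead of your own add_client method, should make things a fair bit simpler.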

TIL about tokio::sync::broadcast as well!

I'm having a little trouble understanding.

What do you mean by:

using the subscribe method on that to get your rx for your /sse stream rather than your own add_client method

In your existing code you have the following in your sse_handler function:

let rx = app_state.broadcaster.add_client().await;

Instead, that could be a call to tokio::sync::broadcast::Sender::subscribe. Of course, the rest of your code also needs to be adapted to use the tokio::sync::broadcast module. But I don't think you need me to do that for you (at least try to do it yourself first and ask more specific questions if you have trouble!)

So you're suggesting having a single broadcast channel instead of a Vec<broadcast::Sender<String>>, and subscribing to it in the sse_handler route?

From this:

pub struct BroadcastEmitter {
    clients: Vec<broadcast::Sender<String>>,
}

to this:

pub struct BroadcastEmitter {
    channel: broadcast::Sender<String>,
}

?

Exactly!
