How do I create Axum handlers for a producer and a consumer with a timeout, in a lock-free style?

Hi team. I'm new to Rust :smiley:. I need to create two HTTP endpoints, one for a producer (POST) and one for a consumer (GET); say the item being consumed is a u64. If the consumer doesn't get a u64 in time (say, 3 seconds), the GET endpoint should immediately respond with 408 Request Timeout. What I ended up with is two queues: one caching produced items (Queue A) and one holding the waiting consumers (Queue B, actually a HashMap<u64, tokio::sync::mpsc::UnboundedSender<u64>> keyed by wait-task id). Queue A is wrapped in a tokio::sync::Mutex and Queue B in a tokio::sync::RwLock.

The GET endpoint first checks whether Queue A has an item. If it does, it pops the first item from A and responds with it. Otherwise it generates a wait-task id and an unbounded u64 mpsc channel, then inserts the (id, sender) pair into Queue B. It then uses tokio::time::timeout to wait on the channel's receiver, and if nothing arrives in time it responds with 408 Request Timeout.
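Below is a dependency-free sketch of just the waiting step. It uses std::sync::mpsc and its blocking `recv_timeout` as a stand-in for tokio's unbounded channel plus `tokio::time::timeout`. The function name `wait_for_item` and the bare `u16` status code are illustrative assumptions. The wait itself enforces the deadline, so the 408 should come back after roughly the timeout, as long as the handler isn't first stuck waiting for a lock:

```rust
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::Duration;

/// Hypothetical helper: `Ok(n)` would become `200 OK` with body `n`,
/// `Err(408)` a `408 Request Timeout`.
fn wait_for_item(rx: &Receiver<u64>, timeout: Duration) -> Result<u64, u16> {
    match rx.recv_timeout(timeout) {
        // A producer handed us a value before the deadline.
        Ok(n) => Ok(n),
        // The deadline passed, or every sender was dropped: either way there
        // is no item for this request.
        Err(RecvTimeoutError::Timeout) | Err(RecvTimeoutError::Disconnected) => Err(408),
    }
}
```

In the tokio version, the equivalent is matching on `tokio::time::timeout(Duration::from_secs(3), rx.recv()).await`, as the handler already does.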

If there is no waiting task in Queue B, the POST endpoint pushes the u64 onto Queue A. Otherwise it takes the first task returned by iterating Queue B, sends the u64 through that task's channel, and afterwards removes the task from Queue B (HashMap::remove()).
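The POST hand-off can be sketched in plain std Rust. This sketch makes two assumptions that differ from the original:

- Waiting consumers sit in a `VecDeque`, so "first" really means oldest. A `HashMap`'s iteration order is arbitrary.
- A waiter is removed at the moment it is picked, not later in a spawned task, so two producers cannot pick the same one.

A failed `send` means that consumer already gave up and dropped its receiver. The value is then retried on the next waiter instead of being discarded as `let _ = task.1.send(number)` does. The `Queues` struct is hypothetical and is meant to sit behind a single lock:

```rust
use std::collections::VecDeque;
use std::sync::mpsc::Sender;

/// Hypothetical combined state: buffered items plus waiting consumers.
struct Queues {
    items: VecDeque<u64>,
    waiters: VecDeque<Sender<u64>>,
}

impl Queues {
    /// Hand `n` to the oldest waiter that is still listening; buffer it otherwise.
    fn produce(&mut self, mut n: u64) {
        // `pop_front` removes the waiter immediately, so it can't be picked twice.
        while let Some(waiter) = self.waiters.pop_front() {
            match waiter.send(n) {
                Ok(()) => return,
                // That consumer dropped its receiver (e.g. it timed out):
                // `SendError` gives the value back, so try the next waiter.
                Err(err) => n = err.0,
            }
        }
        self.items.push_back(n);
    }
}
```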

My implementation always gets stuck once the concurrency level increases. I stress it with Apache Bench (ab). Sometimes it runs fine for thousands of requests, and sometimes it gets stuck immediately. The timeout doesn't behave as expected either: almost all requests wait for more than 3 seconds. So I suspect a serious race condition in how my code acquires the locks (Mutex/RwLock).
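A likely cause of both symptoms is how long the lock guards live. Under the 2021 edition (the one in the Cargo.toml below), a temporary created in an `if let` scrutinee stays alive until the end of the entire `if let ... else` expression. In `consumer`, that temporary is the guard from `app_state.immediate_queue.lock().await`. So the `immediate_queue` mutex is still locked during the 3-second wait in the `else` branch, and every other consumer queues up behind it. That would explain waits well beyond 3 seconds. (The 2024 edition drops these temporaries before the `else` block.)

Binding the result to a local first releases the guard at the end of that `let` statement. This sketch uses std::sync::Mutex so it runs without dependencies; the scoping rule is the same for tokio's guards. The function name is hypothetical:

```rust
use std::collections::VecDeque;
use std::sync::Mutex;

/// Pops an item; if the queue is empty, reports whether the mutex is already
/// free at the point where the real handler would start its 3-second wait.
fn pop_front_scoped(queue: &Mutex<VecDeque<u64>>) -> Result<u64, bool> {
    // The guard is a temporary of this `let` statement, so it is dropped at
    // the semicolon, before either arm below runs.
    let popped = queue.lock().unwrap().pop_front();
    match popped {
        Some(n) => Ok(n),
        // `true` means other consumers could lock the queue while we wait.
        None => Err(queue.try_lock().is_ok()),
    }
}
```

Writing `match queue.lock().unwrap().pop_front() { ... }` directly would not help: temporaries in a `match` scrutinee live until the end of the `match` in every edition.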

Is there a smart way (tools, crates, ...) to implement this kind of scenario? And if possible, could you point out the possible deadlock in my code? Thank you!

```rust
use std::collections::{HashMap, VecDeque};
use std::str::FromStr;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering::SeqCst;
use std::time::Duration;

use axum::body::{Body, Bytes};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Response;
use axum::Router;
use axum::routing::{get, post};
use log::info;
use log::LevelFilter::Info;
use simple_logger::SimpleLogger;
use tokio::sync::{Mutex, RwLock};
use tokio::sync::mpsc::UnboundedSender;

struct AppState {
    immediate_queue: Mutex<VecDeque<u64>>,
    waiting_queue: RwLock<HashMap<u64, UnboundedSender<u64>>>,
    waiting_task_id_generator: AtomicU64,
}

#[tokio::main]
async fn main() {
    SimpleLogger::new()
        .with_local_timestamps()
        .with_level(Info)
        .init()
        .unwrap();
    info!("Main started");

    let arc_app_state = Arc::new(AppState {
        immediate_queue: Mutex::new(VecDeque::new()),
        waiting_queue: RwLock::new(HashMap::new()),
        waiting_task_id_generator: AtomicU64::new(0),
    });

    let router = Router::new()
        .route("/producer", post(producer))
        .route("/consumer", get(consumer))
        .with_state(arc_app_state);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:8000").await.unwrap();
    info!("Listening on {:?}", listener);
    axum::serve(listener, router).await.unwrap();
}


async fn producer(
    State(app_state): State<Arc<AppState>>,
    bytes: Bytes,
) -> Response {
    if let Some(number) = String::from_utf8(bytes.to_vec()).ok()
        .and_then(|num_str| u64::from_str(num_str.as_str()).ok())
    {
        if let Some(task) = app_state.waiting_queue.read().await.iter().next() {
            let _ = task.1.send(number);
            let id = task.0.clone();
            info!("{} sent to waiting queue", number);
            let app_state_clone = app_state.clone();
            tokio::spawn(async move {
                app_state_clone.waiting_queue.write().await.remove(&id);
            });
        } else {
            app_state.immediate_queue.lock().await.push_back(number);
            info!("{} sent to immediate queue", number);
        }
    }
    Response::builder()
        .status(StatusCode::OK)
        .body(Body::empty())
        .unwrap()
}

async fn consumer(
    State(app_state): State<Arc<AppState>>,
) -> Response {
    return if let Some(num) = app_state.immediate_queue.lock().await.pop_front() {
        info!("{} consumed in immediate queue", num);
        Response::builder()
            .status(StatusCode::OK)
            .body(Body::from(format!("{}", num)))
            .unwrap()
    } else {
        let id = app_state.waiting_task_id_generator.fetch_add(1, SeqCst);
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<u64>();
        let _ = app_state.waiting_queue.write().await.insert(id, tx);
        let timeout = tokio::time::timeout(Duration::from_secs(3), rx.recv()).await;
        match timeout {
            Ok(Some(num)) => {
                info!("{} consumed in waiting queue. wait task id={}", num, id);
                Response::builder()
                    .status(StatusCode::OK)
                    .body(Body::from(format!("{}", num)))
                    .unwrap()
            }
            _ => {
                info!("wait but get nothing, wait task id={}", id);
                let app_state = app_state.clone();
                tokio::spawn(async move {
                    app_state.waiting_queue.write().await.remove(&id);
                });
                Response::builder()
                    .status(StatusCode::REQUEST_TIMEOUT)
                    .body(Body::from(format!("wait but get nothing, wait task id={}", id)))
                    .unwrap()
            }
        }
    };
}
```
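The code above also acquires the two locks in opposite orders. Part of the reason is that, under edition 2021 rules, a guard created in an `if let` scrutinee lives until the end of the whole if/else.

- `consumer` holds the `immediate_queue` guard while awaiting `waiting_queue.write()`.
- `producer` holds the `waiting_queue` read guard while awaiting `immediate_queue.lock()`.

Each can end up waiting for the lock the other holds, which is a classic lock-ordering (ABBA) deadlock. tokio's `RwLock` is also fair (write-preferring). A queued writer, such as one of the spawned `remove` tasks, therefore makes later `read()` calls wait too.

One way around this is a single lock over both queues that is never held while waiting. The sketch below does that with std primitives. `Broker`, `Inner`, and returning `None` for the 408 case are hypothetical; the same shape should carry over to `tokio::sync::Mutex` with a tokio channel.

```rust
use std::collections::VecDeque;
use std::sync::mpsc::{self, Sender};
use std::sync::Mutex;
use std::time::Duration;

/// Hypothetical broker: buffered items and waiting consumers share ONE
/// mutex, so there is no lock order to get wrong.
struct Broker {
    inner: Mutex<Inner>,
}

struct Inner {
    items: VecDeque<u64>,
    waiters: VecDeque<(u64, Sender<u64>)>,
    next_id: u64,
}

impl Broker {
    fn new() -> Self {
        Broker {
            inner: Mutex::new(Inner {
                items: VecDeque::new(),
                waiters: VecDeque::new(),
                next_id: 0,
            }),
        }
    }

    /// POST side: hand `n` to the oldest live waiter, else buffer it.
    fn produce(&self, mut n: u64) {
        let mut inner = self.inner.lock().unwrap();
        while let Some((_, tx)) = inner.waiters.pop_front() {
            match tx.send(n) {
                Ok(()) => return,
                // Receiver already dropped: take the value back, try the next.
                Err(err) => n = err.0,
            }
        }
        inner.items.push_back(n);
    }

    /// GET side: `None` stands for the 408 response.
    fn consume(&self, timeout: Duration) -> Option<u64> {
        let (id, rx) = {
            let mut inner = self.inner.lock().unwrap();
            if let Some(n) = inner.items.pop_front() {
                return Some(n);
            }
            let id = inner.next_id;
            inner.next_id += 1;
            let (tx, rx) = mpsc::channel();
            inner.waiters.push_back((id, tx));
            (id, rx)
        }; // the lock is released here, before any waiting happens

        match rx.recv_timeout(timeout) {
            Ok(n) => Some(n),
            Err(_) => {
                // Deregister under the lock. A producer may have sent to us just
                // before we got the lock, so check the channel once more.
                let mut inner = self.inner.lock().unwrap();
                inner.waiters.retain(|(waiter_id, _)| *waiter_id != id);
                drop(inner);
                rx.try_recv().ok()
            }
        }
    }
}
```

Producers only send while holding the lock. So once a timed-out consumer holds the lock, either its value is already in its channel or no producer can reach it any more, and no item is lost.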

Cargo.toml:

```toml
[package]
name = "axum-playground"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = "0.7"
log = "0.4"
tokio = { version = "1.0", features = ["full"] }
simple_logger = "4.2.0"
```

How I test:

```shell
$ ab -n 100000 -c 8 http://127.0.0.1:8000/consumer
$ echo -n 123 > testpost
$ ab -n 110000 -p testpost -c 8 http://127.0.0.1:8000/producer
```

OK, I think I simply need an MPMC channel such as flume or async-channel. I'll do some further testing and report back here.
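An MPMC channel is a single queue that any number of producers push into and any number of consumers wait on, each with its own deadline. With one, the waiter bookkeeping above disappears.

The std-only sketch below, built from `Mutex` and `Condvar`, illustrates that idea only. flume and async-channel are implemented differently and, unlike this sketch, offer an async receive that does not block the thread. `MpmcQueue` is a hypothetical name.

```rust
use std::collections::VecDeque;
use std::sync::{Condvar, Mutex};
use std::time::Duration;

/// Minimal blocking multi-producer, multi-consumer queue (illustrative only).
struct MpmcQueue {
    items: Mutex<VecDeque<u64>>,
    not_empty: Condvar,
}

impl MpmcQueue {
    fn new() -> Self {
        MpmcQueue { items: Mutex::new(VecDeque::new()), not_empty: Condvar::new() }
    }

    fn send(&self, n: u64) {
        // The guard is dropped at the end of this statement, before notifying.
        self.items.lock().unwrap().push_back(n);
        // Wake one waiting consumer, if there is one.
        self.not_empty.notify_one();
    }

    /// Waits at most `timeout` for an item; `None` means timed out (the 408 case).
    fn recv_timeout(&self, timeout: Duration) -> Option<u64> {
        let guard = self.items.lock().unwrap();
        // `wait_timeout_while` releases the mutex while sleeping and re-checks
        // the condition after every wakeup, including spurious ones.
        let (mut guard, _timeout_result) = self
            .not_empty
            .wait_timeout_while(guard, timeout, |q| q.is_empty())
            .unwrap();
        guard.pop_front()
    }
}
```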

Succeeded with flume/async_channel/crossbeam_channel. Here's the flume version:

```rust
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

use axum::body::{Body, Bytes};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Response;
use axum::Router;
use axum::routing::{get, post};
use log::{error, info};
use log::LevelFilter::Info;
use simple_logger::SimpleLogger;

struct AppState {
    receiver: Arc<flume::Receiver<u64>>,
    sender: Arc<flume::Sender<u64>>,
}

#[tokio::main]
async fn main() {
    SimpleLogger::new()
        .with_local_timestamps()
        .with_level(Info)
        .init()
        .unwrap();
    info!("Main started");
    let (tx, rx) = flume::unbounded::<u64>();
    let arc_app_state = Arc::new(AppState {
        receiver: Arc::new(rx),
        sender: Arc::new(tx),
    });

    let router = Router::new()
        .route("/producer", post(producer))
        .route("/consumer", get(consumer))
        .with_state(arc_app_state);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3721").await.unwrap();
    info!("Listening on {:?}", listener);
    axum::serve(listener, router).await.unwrap();
}


async fn producer(
    State(app_state): State<Arc<AppState>>,
    bytes: Bytes,
) -> Response {
    let sender = app_state.sender.clone();
    if let Some(number) = String::from_utf8(bytes.to_vec()).ok()
        .and_then(|num_str| u64::from_str(num_str.as_str()).ok())
    {
        match sender.send(number) {
            Ok(_) => {
                info!("{} sent to immediate queue", number);
            }
            Err(e) => {
                error!("{} failed to be sent: {:?}", number, e);
            }
        }
    }
    Response::builder()
        .status(StatusCode::OK)
        .body(Body::empty())
        .unwrap()
}

async fn consumer(
    State(app_state): State<Arc<AppState>>,
) -> Response {
    let receiver = app_state.receiver.clone();
    // flume's `recv()` blocks the worker thread, so wrapping it in an `async`
    // block never lets the timeout fire. `recv_async()` returns a future that
    // `tokio::time::timeout` can cancel after 3 seconds.
    let timeout = tokio::time::timeout(Duration::from_secs(3), receiver.recv_async()).await;
    match timeout {
        Ok(Ok(num)) => {
            info!("{} consumed in waiting queue.", num);
            Response::builder()
                .status(StatusCode::OK)
                .body(Body::from(format!("{}", num)))
                .unwrap()
        }
        _ => {
            info!("wait but get nothing.");
            Response::builder()
                .status(StatusCode::REQUEST_TIMEOUT)
                .body(Body::from("wait but get nothing."))
                .unwrap()
        }
    }
}
```

crossbeam_channel

Please be aware that crossbeam is not interchangeable with the other channels because it is a non-async channel.