Axum / tokio shared state channel confusion

I seem to have a conflict between sharing state in Axum and using a tokio::sync channel rather than a std::sync channel.

The Tokio documentation says I should use a tokio::sync channel to communicate between synchronous and asynchronous code (I understand this is because it's wrong to block in async code). However, when I try to do this, I get compile errors from Axum.

Here's my program that works (but I think its blocking behaviour isn't quite right):

use axum::{
    extract::{Extension, Form, Multipart, Path, Query},
    response::Html,
    routing::get,
    AddExtensionLayer, Router,
};

use tower::ServiceBuilder;
use tower_cookies::{Cookie, CookieManagerLayer, Cookies};

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

//use database::{Database};
use database::genquery::GenQuery;

use std::sync::mpsc;
use std::sync::mpsc::{SyncSender,Receiver};
use std::thread;

/// A boxed query: the message type passed to and from the server thread.
type GQ = Box<GenQuery>;

/// State shared by every handler: the sending half of the request channel
/// and the receiving half of the reply channel.
struct SharedState {
    /// Sends queries to the synchronous server thread.
    tx1: SyncSender<GQ>,
    /// Receives the modified queries back from the server thread.
    rx2: Receiver<GQ>,
}

/// Entry point: starts the synchronous "server" thread, then serves the axum
/// app on 0.0.0.0:3000.
///
/// Handlers talk to the server thread over two rendezvous channels
/// (`sync_channel(0)`): `tx1`/`rx1` carries queries to the thread and
/// `tx2`/`rx2` carries the modified queries back.
#[tokio::main]
async fn main() {
    let (tx1, rx1): (SyncSender<GQ>, Receiver<GQ>) = mpsc::sync_channel(0);
    let (tx2, rx2): (SyncSender<GQ>, Receiver<GQ>) = mpsc::sync_channel(0);

    // This is the server thread (synchronous). It is a plain OS thread, so
    // blocking on the std channels here is fine. The loop ends cleanly when
    // the request channel is disconnected or the reply receiver is dropped.
    thread::spawn(move || {
        let mut counter = 0;
        while let Ok(mut q) = rx1.recv() {
            println!("Server got query {}", q.path);
            q.headers = format!("counter={}", counter);
            if tx2.send(q).is_err() {
                break;
            }
            counter += 1;
        }
    });

    let state = Arc::new(Mutex::new(SharedState { tx1, rx2 }));

    // build our application with a single route
    let app = Router::new()
        .route("/*key", get(my_get_handler).post(my_post_handler))
        .layer(
            ServiceBuilder::new()
                .layer(CookieManagerLayer::new())
                .layer(AddExtensionLayer::new(state)),
        );

    // run it with hyper on port 3000, listening on all interfaces (0.0.0.0)
    axum::Server::bind(&"0.0.0.0:3000".parse().expect("hard-coded address is valid"))
        .serve(app.into_make_service())
        .await
        .expect("HTTP server failed");
}

/// GET handler: sets a cookie, forwards the request path and query parameters
/// to the server thread, and renders an HTML page that includes the thread's
/// reply (the `headers` field it fills in).
///
/// Request-derived text is HTML-escaped before being embedded in the page.
async fn my_get_handler(
    // Wow, fancy pattern matching parameters...
    Extension(state): Extension<Arc<Mutex<SharedState>>>,
    Path(path): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    cookies: Cookies,
) -> Html<String> {
    /// Escapes `&`, `<`, `>`, `"` and `'` so request-supplied text can be
    /// embedded in the HTML response without being interpreted as markup.
    fn escape_html(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let result = "hello world!";

    let mut c = Cookie::new("MyCookie", "Hi George");
    c.set_path("/");
    cookies.add(c);

    let s = format!(
        "<p>Hi George result = {} path is '{}' params are {} <p>cookies are {}\
         <p>Simple form<form method=post><input name=n>\
         </form>\
         <p>Form with file<form method=post enctype=\"multipart/form-data\"><input name=n>\
         <input name=f type=file>\
         </form>\
         ",
        result,
        escape_html(&path),
        escape_html(&format!("{:?}", params)),
        escape_html(&format!("{:?}", cookies)),
    );

    let mut q = Box::new(database::genquery::GenQuery::new());
    q.path = path;
    q.query = params;

    // The std mutex and std channels block the calling thread. Running the
    // exchange inside `block_in_place` tells the (multi-threaded) tokio runtime
    // that this worker is about to block, so its other tasks are handed to
    // another worker meanwhile. A poisoned lock or a vanished server thread
    // yields `None` instead of panicking the handler.
    let reply = tokio::task::block_in_place(move || {
        let guard = state.lock().ok()?;
        guard.tx1.send(q).ok()?;
        guard.rx2.recv().ok()
    });
    let received = reply.map_or_else(|| "(server thread unavailable)".to_string(), |q| q.headers);
    println!("Got: {}", received);

    let s = format!("{} received={}", s, escape_html(&received));

    Html(s)
}

/// POST handler: echoes the path, query parameters, cookies, URL-encoded form
/// and any multipart fields back as an HTML page.
///
/// Multipart fields without a content type are read as text; other fields are
/// read as bytes and only their length is reported. A malformed multipart body
/// ends the field listing with an error note instead of panicking. All
/// request-derived text is HTML-escaped before it is embedded in the page.
async fn my_post_handler(
    // More fancy parameters...
    Extension(_state): Extension<Arc<Mutex<SharedState>>>,
    Path(path): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    cookies: Cookies,
    form: Option<Form<HashMap<String, String>>>,
    mp: Option<Multipart>,
) -> Html<String> {
    /// Escapes `&`, `<`, `>`, `"` and `'` so request-supplied text can be
    /// embedded in the HTML response without being interpreted as markup.
    fn escape_html(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let mut mpinfo = String::new();
    if let Some(mut mp) = mp {
        mpinfo += "multipart form!!";
        loop {
            let field = match mp.next_field().await {
                Ok(Some(field)) => field,
                Ok(None) => break,
                Err(err) => {
                    mpinfo += &format!(
                        "<p>error reading multipart body: {}",
                        escape_html(&err.to_string())
                    );
                    break;
                }
            };

            // Copy the metadata out first: `text()` / `bytes()` consume the field.
            let name = field.name().map(|s| s.to_string()).unwrap_or_default();
            let filename = field
                .file_name()
                .map(|s| s.to_string())
                .unwrap_or_else(|| "No filename".to_string());
            let ct = field
                .content_type()
                .map(|m| m.to_string())
                .unwrap_or_default();

            // Untyped fields are read as text; anything else as raw bytes.
            let (datalen, text) = if ct.is_empty() {
                (0, field.text().await.unwrap_or_default())
            } else {
                (field.bytes().await.map(|b| b.len()).unwrap_or(0), String::new())
            };

            mpinfo += &format!(
                "<p>name is `{}` filename is `{}` ct is `{}` data len is {} bytes text is {}",
                escape_html(&name),
                escape_html(&filename),
                escape_html(&ct),
                datalen,
                escape_html(&text)
            );
        }
    }
    let s = format!(
        "<p>Hi George path is '{}' and params are {}  <p>cookies {} form is {} mpinfo {}",
        escape_html(&path),
        escape_html(&format!("{:?}", params)),
        escape_html(&format!("{:?}", cookies)),
        escape_html(&format!("{:?}", form)),
        mpinfo
    );
    Html(s)
}

This version tries to use a tokio channel, but it doesn't compile:

use axum::{
    extract::{Extension, Form, Multipart, Path, Query},
    response::Html,
    routing::get,
    AddExtensionLayer, Router,
};

use tower::ServiceBuilder;
use tower_cookies::{Cookie, CookieManagerLayer, Cookies};

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

//use database::{Database};
use database::genquery::GenQuery;

use tokio::sync::mpsc;
use tokio::sync::mpsc::{Sender,Receiver};
use std::thread;

type GQ = Box<GenQuery>;

struct SharedState
{
  tx1 : mpsc::Sender<GQ>,
  rx2 : mpsc::Receiver<GQ>,
}

/// Entry point of the version that does not compile: starts the synchronous
/// "server" thread, then serves the axum app on 0.0.0.0:3000.
#[tokio::main]
async fn main() {

    // Two bounded tokio channels of capacity 1: `tx1`/`rx1` carries queries to
    // the server thread, `tx2`/`rx2` carries the modified queries back.
    let (tx1, rx1) : (Sender<GQ>,Receiver<GQ>) = mpsc::channel(1);
    let (tx2, rx2) : (Sender<GQ>,Receiver<GQ>) = mpsc::channel(1);

    // This is the server thread (synchronous).
    // NOTE(review): `rx1` is not declared `mut`, but tokio's
    // `Receiver::blocking_recv` takes `&mut self`, so this is a further
    // compile error; the later, working version declares `mut rx1`.
    thread::spawn(move || 
    {
        let mut counter = 0;
        loop
        {
          let mut q = rx1.blocking_recv().unwrap();
          println!( "Server got query {}", q.path );
          q.headers = format!( "counter={}", counter );
          // NOTE(review): the `Result` returned by `blocking_send` is ignored.
          tx2.blocking_send( q );
          counter += 1;
        }
    });

    let state = Arc::new(Mutex::new(SharedState{ tx1, rx2 }));

    // build our application with a single route
    let app = Router::new()
        .route("/*key", get(my_get_handler).post(my_post_handler))
        .layer(
            ServiceBuilder::new()
                .layer(CookieManagerLayer::new())
                .layer(AddExtensionLayer::new(state)),
        );

    // run it with hyper on localhost:3000
    // (the address 0.0.0.0 actually listens on all interfaces)
    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

/// GET handler from the version that does not compile: sets a cookie, sends
/// the request path and query parameters to the server thread, and renders an
/// HTML page containing the thread's reply.
async fn my_get_handler(
    // Wow, fancy pattern matching parameters...
    Extension(state): Extension<Arc<Mutex<SharedState>>>,
    Path(path): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    cookies: Cookies,
) -> Html<String> {

    // The closure only returns a constant, so `block_in_place` does no
    // blocking work here.
    let result = tokio::task::block_in_place(|| "hello world!");

    let mut c = Cookie::new("MyCookie", "Hi George");
    c.set_path("/");
    cookies.add(c);

    // NOTE(review): path, params and cookies come from the request and are
    // embedded in the HTML unescaped.
    let s = format!(
        "<p>Hi George result = {} path is '{}' params are {:?} <p>cookies are {:?}\
         <p>Simple form<form method=post><input name=n>\
         </form>\
         <p>Form with file<form method=post enctype=\"multipart/form-data\"><input name=n>\
         <input name=f type=file>\
         </form>\
         ",
        result, path, params, cookies
    );

    let mut q = Box::new(database::genquery::GenQuery::new());
    q.path = path;
    q.query = params;

    // Round trip to the server thread while holding the shared-state lock.
    let received =
    {
      // NOTE: this is what makes the handler fail to compile. `state` is a
      // `std::sync::MutexGuard`, which is not `Send`, and it is still alive
      // across the `.await`s below. The handler's future is therefore not
      // `Send`, and axum reports that `Handler` is not implemented. Locking a
      // `tokio::sync::Mutex` with `.lock().await` (its guard is `Send`)
      // avoids this.
      let mut state = state.lock().unwrap();
      // The `Result` returned by `send` is discarded.
      state.tx1.send( q ).await;
      let q = state.rx2.recv().await.unwrap();
      q.headers
    };
    println!("Got: {}", received);

    let s = format!( "{} received={}", s, received );

    Html(s)
}

/// POST handler: echoes the path, query parameters, cookies, URL-encoded form
/// and any multipart fields back as an HTML page.
///
/// Multipart fields without a content type are read as text; other fields are
/// read as bytes and only their length is reported. A malformed multipart body
/// ends the field listing with an error note instead of panicking. All
/// request-derived text is HTML-escaped before it is embedded in the page.
async fn my_post_handler(
    // More fancy parameters...
    Extension(_state): Extension<Arc<Mutex<SharedState>>>,
    Path(path): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    cookies: Cookies,
    form: Option<Form<HashMap<String, String>>>,
    mp: Option<Multipart>,
) -> Html<String> {
    /// Escapes `&`, `<`, `>`, `"` and `'` so request-supplied text can be
    /// embedded in the HTML response without being interpreted as markup.
    fn escape_html(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let mut mpinfo = String::new();
    if let Some(mut mp) = mp {
        mpinfo += "multipart form!!";
        loop {
            let field = match mp.next_field().await {
                Ok(Some(field)) => field,
                Ok(None) => break,
                Err(err) => {
                    mpinfo += &format!(
                        "<p>error reading multipart body: {}",
                        escape_html(&err.to_string())
                    );
                    break;
                }
            };

            // Copy the metadata out first: `text()` / `bytes()` consume the field.
            let name = field.name().map(|s| s.to_string()).unwrap_or_default();
            let filename = field
                .file_name()
                .map(|s| s.to_string())
                .unwrap_or_else(|| "No filename".to_string());
            let ct = field
                .content_type()
                .map(|m| m.to_string())
                .unwrap_or_default();

            // Untyped fields are read as text; anything else as raw bytes.
            let (datalen, text) = if ct.is_empty() {
                (0, field.text().await.unwrap_or_default())
            } else {
                (field.bytes().await.map(|b| b.len()).unwrap_or(0), String::new())
            };

            mpinfo += &format!(
                "<p>name is `{}` filename is `{}` ct is `{}` data len is {} bytes text is {}",
                escape_html(&name),
                escape_html(&filename),
                escape_html(&ct),
                datalen,
                escape_html(&text)
            );
        }
    }
    let s = format!(
        "<p>Hi George path is '{}' and params are {}  <p>cookies {} form is {} mpinfo {}",
        escape_html(&path),
        escape_html(&format!("{:?}", params)),
        escape_html(&format!("{:?}", cookies)),
        escape_html(&format!("{:?}", form)),
        mpinfo
    );
    Html(s)
}

Error message:

C:\Users\pc\rust\axumtest>cargo run
   Compiling axumtest v0.1.0 (C:\Users\pc\rust\axumtest)
error[E0277]: the trait bound `fn(Extension<Arc<std::sync::Mutex<SharedState>>>, axum::extract::Path<std::string::String>, axum::extract::Query<HashMap<std::string::String, std::string::String>>, Cookies) -> impl Future {my_get_handler}: Handler<_, _>` is not satisfied
   --> src\main.rs:53:29
    |
53  |         .route("/*key", get(my_get_handler).post(my_post_handler))
    |                             ^^^^^^^^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(Extension<Arc<std::sync::Mutex<SharedState>>>, axum::extract::Path<std::string::String>, axum::extract::Query<HashMap<std::string::String, std::string::String>>, Cookies) -> impl Future {my_get_handler}`
    |
note: required by a bound in `axum::routing::get`
   --> C:\Users\pc\.cargo\registry\src\github.com-1ecc6299db9ec823\axum-0.3.2\src\routing\handler_method_routing.rs:101:8
    |
101 |     H: Handler<B, T>,
    |        ^^^^^^^^^^^^^ required by this bound in `axum::routing::get`

error[E0277]: the trait bound `fn(Extension<Arc<std::sync::Mutex<SharedState>>>, axum::extract::Path<std::string::String>, axum::extract::Query<HashMap<std::string::String, std::string::String>>, Cookies) -> impl Future {my_get_handler}: Handler<_, _>` is not satisfied
  --> src\main.rs:53:25
   |
53 |         .route("/*key", get(my_get_handler).post(my_post_handler))
   |                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(Extension<Arc<std::sync::Mutex<SharedState>>>, axum::extract::Path<std::string::String>, axum::extract::Query<HashMap<std::string::String, std::string::String>>, Cookies) -> impl Future {my_get_handler}`
   |
   = note: required because of the requirements on the impl of `Service<Request<_>>` for `axum::routing::MethodRouter<fn(Extension<Arc<std::sync::Mutex<SharedState>>>, axum::extract::Path<std::string::String>, axum::extract::Query<HashMap<std::string::String, std::string::String>>, Cookies) -> impl Future {my_get_handler}, _, _, MethodNotAllowed>`
   = note: 1 redundant requirements hidden
   = note: required because of the requirements on the impl of `Service<Request<_>>` for `axum::routing::MethodRouter<fn(Extension<Arc<std::sync::Mutex<SharedState>>>, axum::extract::Path<std::string::String>, axum::extract::Query<HashMap<std::string::String, std::string::String>>, Cookies, Option<Form<HashMap<std::string::String, std::string::String>>>, Option<Multipart>) -> impl Future {my_post_handler}, _, (Extension<Arc<std::sync::Mutex<SharedState>>>, axum::extract::Path<std::string::String>, axum::extract::Query<HashMap<std::string::String, std::string::String>>, Cookies, Option<Form<HashMap<std::string::String, std::string::String>>>, Option<Multipart>), axum::routing::MethodRouter<fn(Extension<Arc<std::sync::Mutex<SharedState>>>, axum::extract::Path<std::string::String>, axum::extract::Query<HashMap<std::string::String, std::string::String>>, Cookies) -> impl Future {my_get_handler}, _, _, MethodNotAllowed>>`

For more information about this error, try `rustc --explain E0277`.
error: could not compile `axumtest` due to 2 previous errors

I cannot figure out what I need to do. Help!

Perhaps I should say what I am trying to do in general terms...

(1) I want a synchronous "server task" that loops: it receives a message (of type Box<GenQuery>) through a channel, modifies it a bit, and sends it back.

(2) The async handler task does some processing on the incoming http request, locks the shared state, sends the message to the "server task", waits for it to come back, unlocks the shared state, and returns ( with the response generated by the server task ).

[ I guess the handler task needs to somehow tell the tokio runtime "I am going to block", so the tokio runtime doesn't try to run other task on the thread. I think I saw something about this in async-std, but by now I am pretty confused as all this is totally new to me. ]

Hmm, looks like I should be using a tokio Mutex

But that doesn't solve my compile error problem.

Well, that's strange: I switched to tokio::sync::Mutex, and everything is now working. I guess I changed something else subtle at the same time without realising. Anyway, I think it's all working now... even though I don't know what changed:

use axum::{
    extract::{Extension, Form, Multipart, Path, Query},
    response::Html,
    routing::get,
    AddExtensionLayer, Router,
};

use tower::ServiceBuilder;
use tower_cookies::{Cookie, CookieManagerLayer, Cookies};

use std::collections::HashMap;
use std::sync::{Arc};
use tokio::sync::Mutex;

//use database::{Database};
use database::genquery::GenQuery;

use tokio::sync::mpsc;
use tokio::sync::mpsc::{Sender,Receiver};
use std::thread;

type GQ = Box<GenQuery>;

struct SharedState
{
  tx1 : mpsc::Sender<GQ>,
  rx2 : mpsc::Receiver<GQ>,
}

/// Entry point: starts the synchronous "server" thread, then serves the axum
/// app on 0.0.0.0:3000.
///
/// Handlers talk to the server thread over two bounded tokio channels of
/// capacity 1: `tx1`/`rx1` carries queries to the thread and `tx2`/`rx2`
/// carries the modified queries back.
#[tokio::main]
async fn main() {
    let (tx1, mut rx1): (Sender<GQ>, Receiver<GQ>) = mpsc::channel(1);
    let (tx2, rx2): (Sender<GQ>, Receiver<GQ>) = mpsc::channel(1);

    // This is the server thread (synchronous). It is a plain OS thread rather
    // than a tokio task, so it uses the `blocking_*` channel methods. The loop
    // ends cleanly when the request channel closes or the reply receiver is
    // dropped, rather than panicking or discarding a failed send.
    thread::spawn(move || {
        let mut counter = 0;
        while let Some(mut q) = rx1.blocking_recv() {
            println!("Server got query {}", q.path);
            q.headers = format!("counter={}", counter);
            if tx2.blocking_send(q).is_err() {
                break;
            }
            counter += 1;
        }
    });

    let state = Arc::new(Mutex::new(SharedState { tx1, rx2 }));

    // build our application with a single route
    let app = Router::new()
        .route("/*key", get(my_get_handler).post(my_post_handler))
        .layer(
            ServiceBuilder::new()
                .layer(CookieManagerLayer::new())
                .layer(AddExtensionLayer::new(state)),
        );

    // run it with hyper on port 3000, listening on all interfaces (0.0.0.0)
    axum::Server::bind(&"0.0.0.0:3000".parse().expect("hard-coded address is valid"))
        .serve(app.into_make_service())
        .await
        .expect("HTTP server failed");
}

/// GET handler: sets a cookie, forwards the request path and query parameters
/// to the server thread, and renders an HTML page that includes the thread's
/// reply (the `headers` field it fills in).
///
/// Request-derived text is HTML-escaped before being embedded in the page.
async fn my_get_handler(
    // Wow, fancy pattern matching parameters...
    Extension(state): Extension<Arc<Mutex<SharedState>>>,
    Path(path): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    cookies: Cookies,
) -> Html<String> {
    /// Escapes `&`, `<`, `>`, `"` and `'` so request-supplied text can be
    /// embedded in the HTML response without being interpreted as markup.
    fn escape_html(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let result = "hello world!";

    let mut c = Cookie::new("MyCookie", "Hi George");
    c.set_path("/");
    cookies.add(c);

    let s = format!(
        "<p>Hi George result = {} path is '{}' params are {} <p>cookies are {}\
         <p>Simple form<form method=post><input name=n>\
         </form>\
         <p>Form with file<form method=post enctype=\"multipart/form-data\"><input name=n>\
         <input name=f type=file>\
         </form>\
         ",
        result,
        escape_html(&path),
        escape_html(&format!("{:?}", params)),
        escape_html(&format!("{:?}", cookies)),
    );

    let mut q = Box::new(database::genquery::GenQuery::new());
    q.path = path;
    q.query = params;

    // Hold the async mutex for the whole send/receive round trip so each
    // handler reads the reply to its own query. `tokio::sync::MutexGuard` is
    // `Send`, so holding it across `.await` keeps this future `Send`, as axum
    // requires. If the server thread is gone, the send fails or `recv` yields
    // `None`; either way the handler reports it instead of panicking.
    let mut guard = state.lock().await;
    let reply = match guard.tx1.send(q).await {
        Ok(()) => guard.rx2.recv().await,
        Err(_) => None,
    };
    drop(guard);

    let received = reply.map_or_else(|| "(server thread unavailable)".to_string(), |q| q.headers);
    println!("Got: {}", received);

    let s = format!("{} received={}", s, escape_html(&received));

    Html(s)
}

/// POST handler: echoes the path, query parameters, cookies, URL-encoded form
/// and any multipart fields back as an HTML page.
///
/// Multipart fields without a content type are read as text; other fields are
/// read as bytes and only their length is reported. A malformed multipart body
/// ends the field listing with an error note instead of panicking. All
/// request-derived text is HTML-escaped before it is embedded in the page.
async fn my_post_handler(
    // More fancy parameters...
    Extension(_state): Extension<Arc<Mutex<SharedState>>>,
    Path(path): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    cookies: Cookies,
    form: Option<Form<HashMap<String, String>>>,
    mp: Option<Multipart>,
) -> Html<String> {
    /// Escapes `&`, `<`, `>`, `"` and `'` so request-supplied text can be
    /// embedded in the HTML response without being interpreted as markup.
    fn escape_html(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let mut mpinfo = String::new();
    if let Some(mut mp) = mp {
        mpinfo += "multipart form!!";
        loop {
            let field = match mp.next_field().await {
                Ok(Some(field)) => field,
                Ok(None) => break,
                Err(err) => {
                    mpinfo += &format!(
                        "<p>error reading multipart body: {}",
                        escape_html(&err.to_string())
                    );
                    break;
                }
            };

            // Copy the metadata out first: `text()` / `bytes()` consume the field.
            let name = field.name().map(|s| s.to_string()).unwrap_or_default();
            let filename = field
                .file_name()
                .map(|s| s.to_string())
                .unwrap_or_else(|| "No filename".to_string());
            let ct = field
                .content_type()
                .map(|m| m.to_string())
                .unwrap_or_default();

            // Untyped fields are read as text; anything else as raw bytes.
            let (datalen, text) = if ct.is_empty() {
                (0, field.text().await.unwrap_or_default())
            } else {
                (field.bytes().await.map(|b| b.len()).unwrap_or(0), String::new())
            };

            mpinfo += &format!(
                "<p>name is `{}` filename is `{}` ct is `{}` data len is {} bytes text is {}",
                escape_html(&name),
                escape_html(&filename),
                escape_html(&ct),
                datalen,
                escape_html(&text)
            );
        }
    }
    let s = format!(
        "<p>Hi George path is '{}' and params are {}  <p>cookies {} form is {} mpinfo {}",
        escape_html(&path),
        escape_html(&format!("{:?}", params)),
        escape_html(&format!("{:?}", cookies)),
        escape_html(&format!("{:?}", form)),
        mpinfo
    );
    Html(s)
}

The axum-debug crate should help you diagnose issues like this. Don't mind that it depends on an older version of axum, I'm pretty sure it will still work (it really shouldn't depend on axum at all).

2 Likes

Yeah, give axum-debug a try. I suspect what's happening is that you're holding a !Send type across an .await, which makes your future !Send, and axum doesn't support that. tokio's mutex guards are Send, which is probably why it works when you switch to that.

There are more details about debugging handler type errors here.

1 Like

Yes, thanks, I get it now. I didn't realise that code INSIDE the function (rather than in its signature) could cause a type error; that's a bit counter-intuitive. Send and async are completely new to me, and I'm still working through the learning curve.

1 Like

fyi I just published axum-debug 0.2 which works better with axum 0.3. Let me know if that helps.

I can confirm axum-debug does work. On the example I gave above it produces:

---------- Check ----------
    Checking axumtest v0.1.0 (C:\Users\pc\rust\axumtest)
error: future cannot be sent between threads safely
   --> src\main.rs:69:10
    |
69  | async fn my_get_handler(
    |          ^^^^^^^^^^^^^^ future returned by `my_get_handler` is not `Send`
    |
    = help: within `impl Future`, the trait `Send` is not implemented for `std::sync::MutexGuard<'_, SharedState>`
note: future is not `Send` as this value is used across an await
   --> src\main.rs:102:15
    |
100 |       let mut state = state.lock().unwrap();
    |           --------- has type `std::sync::MutexGuard<'_, SharedState>` which is not `Send`
101 |       state.tx1.send( q ).await;
102 |       let q = state.rx2.recv().await.unwrap();
    |               ^^^^^^^^^^^^^^^^^^^^^^ await occurs here, with `mut state` maybe used later
103 |       q.headers
104 |     };
    |     - `mut state` is later dropped here
note: required by a bound in `my_get_handler::{closure#0}::debug_handler`
   --> src\main.rs:69:10
    |
69  | async fn my_get_handler(
    |          ^^^^^^^^^^^^^^ required by this bound in `my_get_handler::{closure#0}::debug_handler`

error: could not compile `axumtest` due to previous error
Normal Termination
Output completed (1 sec consumed).

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.