MQTT command registry with traits and serde

Hi!

I'm building a relay controller module that listens for commands over MQTT. I'm trying to separate the command implementations from the MQTT event loop handling. I created an interface struct for the MQTT communication that runs the async event loop and lets me register command handlers as async functions. Each command is decoded from bytes into its own type, and each response is encoded into bytes from a task-specific type as well. I used typetag to decode and encode dynamically inside the event loop. I just can't figure out how to register async functions that take the decoded objects and return the response data.

Here's a summary of what I have so far:


#[typetag::deserialize]
pub trait Request {}

#[typetag::serialize(tag = "success")]
pub trait ResponseSuccess {}

#[typetag::serialize(tag = "failure")]
pub trait ResponseFailure {}

type RequestHandler = Box<
    dyn Fn(
            Box<dyn Request>,
        ) -> Pin<
            Box<
                dyn Future<
                        Output = Result<
                            Box<dyn ResponseSuccess + Send + Sync>,
                            Box<dyn ResponseFailure + Send + Sync>,
                        >,
                    > + Send
                    + Sync,
            >,
        > + Send
        + Sync,
>;
type RequestHandlerRegistry = Arc<Mutex<HashMap<String, RequestHandler>>>;

pub struct Interface {
    handlers: RequestHandlerRegistry,
}
impl Interface {
    pub async fn add_handler(&self, topic: &str, handler: RequestHandler) -> Result<(), ClientError> {
        ...
        Ok(())
    }

    pub async fn accept(&self) -> Result<(), InterfaceError> {
        let handlers = self.handlers.lock().await;
        let Some(request_handler) = handlers.get(&topic) else {
            return Ok(());
        };

        tokio::spawn(async move {
            let request = rmp_serde::decode::from_slice(&packet.payload);
            let response = request_handler(request).await;
            let response = rmp_serde::encode::to_vec(&response);

            // Send response
            ...
        });

        Ok(())
    }
}

I'd like to be able to do something like


#[derive(Deserialize)]
struct GetRelayStatus(u8);

#[typetag::deserialize]
impl Request for GetRelayStatus {}

async fn get_status(cmd: GetRelayStatus) -> Result<bool, RelayInterfaceError>{
    // Toggle logic
    ...
    Ok(status)
}
...
interface.add_handler("relay/set", get_status);

Of course these types are wrong. How would something like this be possible?
I already made the app work by storing channels in the handler registry that send and receive bytes. I just don't like having to implement serialization and deserialization inside the handlers.
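For comparison, here is a rough sketch of what such a channel-based registry might look like. The original code wasn't posted, so its shape is an assumption: to keep it std-only, `std::thread` and `std::sync::mpsc` stand in for tokio tasks and channels, and the `device/get_status` topic, the decimal-text payload and the `build_registry`/`dispatch` helpers are hypothetical. The point is that the handler thread has to decode its own request and encode its own reply.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, Sender};
use std::thread;

// A raw request: the payload bytes plus a channel for the raw reply.
type RawRequest = (Vec<u8>, Sender<Vec<u8>>);

// Hypothetical registry: topic -> channel into that command's handler thread.
fn build_registry() -> HashMap<String, Sender<RawRequest>> {
    let mut registry = HashMap::new();
    let (tx, rx) = mpsc::channel::<RawRequest>();
    registry.insert("device/get_status".to_string(), tx);

    // The handler has to decode the request and encode the reply itself.
    thread::spawn(move || {
        for (payload, reply) in rx {
            let device_id: Option<u16> = std::str::from_utf8(&payload)
                .ok()
                .and_then(|s| s.trim().parse().ok());
            let body = match device_id {
                Some(id) => format!("busy={}", id == 1), // placeholder device logic
                None => "bad request".to_string(),
            };
            let _ = reply.send(body.into_bytes());
        }
    });
    registry
}

// Sends a payload to the topic's handler and waits for the raw reply.
fn dispatch(registry: &HashMap<String, Sender<RawRequest>>, topic: &str, payload: &[u8]) -> Option<Vec<u8>> {
    let handler = registry.get(topic)?;
    let (reply_tx, reply_rx) = mpsc::channel();
    handler.send((payload.to_vec(), reply_tx)).ok()?;
    reply_rx.recv().ok()
}

fn main() {
    let registry = build_registry();
    let reply = dispatch(&registry, "device/get_status", b"1").unwrap();
    println!("{}", String::from_utf8(reply).unwrap());
}
```

Every new command repeats the parse/format boilerplate inside its own handler, which is the part the question wants to move into the interface.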

Okay, I think I figured it out. Here's my current best iteration of what I want the MQTT interface to look like. This example finally compiles.

use bytes::Bytes;
use std::sync::Arc;
use std::future::Future;
use std::pin::Pin;
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::Mutex;
use serde::{Deserialize, Serialize};


#[derive(Clone)]
struct AsyncClient {}

impl AsyncClient {
    async fn publish(&self, topic: &str, payload: Bytes) {
        println!("Publishing to {topic}: {payload:?}");
    }
}

type RequestHandler =
    Arc<dyn Fn(Bytes) -> Pin<Box<dyn Future<Output = Bytes> + Send>> + Send + Sync>;

struct Interface {
    client: AsyncClient,
    handlers: Arc<Mutex<HashMap<String, RequestHandler>>>,
}

impl Interface {
    fn new() -> Self {
        Interface {
            client: AsyncClient {},
            handlers: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    async fn run(&self) {
        loop {
            let Message {
                topic,
                response_topic,
                payload,
            } = recv_message().await;

            let handlers = self.handlers.lock().await;
            let handler = handlers.get(&topic).unwrap().clone();
            let client = self.client.clone();
            tokio::spawn(async move {
                let response_payload = handler(payload).await;
                client.publish(&response_topic, response_payload).await;
            });
        }
    }

    async fn add_handler<F, O, M, R>(&mut self, topic: &str, func: F)
    where
        M: for<'a> Deserialize<'a>,
        R: Serialize,
        O: Future<Output = R> + Send,
        F: Fn(M) -> O + Send + Sync + 'static,
    {
        let func_ref = Arc::new(func);
        let mut handlers = self.handlers.lock().await;
        handlers.insert(
            String::from(topic),
            Arc::new(move |request| {
                let func = func_ref.clone();
                Box::pin(async move {
                    let request = serde_json::from_slice(&request).unwrap();
                    let response = func(request).await;
                    serde_json::to_vec(&response).unwrap().into()
                })
            }),
        );
    }
}

#[derive(Deserialize)]
struct GetStatusCommand {
    device_id: u16,
}

#[derive(Serialize)]
struct Status {
    busy: bool,
}

// Dummy struct. Represents the internal message type in rumqttc.
struct Message {
    topic: String,
    response_topic: String,
    payload: Bytes,
}

// Dummy function. Represents the internals of rumqttc.
async fn recv_message() -> Message {
    tokio::time::sleep(Duration::from_secs(1)).await;
    Message {
        topic: String::from("device/get_status"),
        response_topic: String::from("device/status"),
        payload: Bytes::from("{\"device_id\": 1}"),
    }
}

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut interface = Interface::new();
    interface
        .add_handler("device/get_status", async |command: GetStatusCommand| {
            println!("Getting status of device {} ", command.device_id);
            // Communicate with the device...
            tokio::time::sleep(Duration::from_secs(1)).await;
            Status { busy: true }
        })
        .await;

    let handle = tokio::spawn(async move {
        interface.run().await;
    });

    handle.await?;

    Ok(())
}

No custom traits needed. This accomplishes my original goal of doing the encoding/decoding inside the interface while attaching handlers with well-defined request and response types from the outside. It also lets me spawn a task for each message and call the registered handler within it. I'm still interested in better alternatives: fewer Arcs, maybe using an async closure inside add_handler instead of the nested closure, etc.
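On the "fewer Arcs" question, one possible direction (a sketch, not tested against rumqttc or tokio) is to call the handler before spawning. Because the handler returns an owned `'static` future, the map only needs to be borrowed long enough to create that future, so `RequestHandler` can be a `Box`, and the `Arc<F>` can become an `F: Clone` bound (function items, and closures whose captures are `Clone`, satisfy it). Since `add_handler` already takes `&mut self`, a plain `HashMap` may be enough if all handlers are registered before `run`. To keep it std-only, the hypothetical `Decode`/`Encode` traits stand in for serde, `Vec<u8>` for `Bytes`, and a minimal `block_on` for the tokio runtime.

```rust
use std::collections::HashMap;
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Hypothetical stand-ins for serde's Deserialize / Serialize.
trait Decode: Sized { fn decode(bytes: &[u8]) -> Option<Self>; }
trait Encode { fn encode(&self) -> Vec<u8>; }

type BoxFuture = Pin<Box<dyn Future<Output = Vec<u8>> + Send>>;
// A Box is enough: the handler is only borrowed to create a future.
type RequestHandler = Box<dyn Fn(Vec<u8>) -> BoxFuture + Send + Sync>;

#[derive(Default)]
struct Interface {
    // Plain map, assuming every handler is registered before `run`.
    handlers: HashMap<String, RequestHandler>,
}

impl Interface {
    fn add_handler<F, O, M, R>(&mut self, topic: &str, func: F)
    where
        M: Decode + Send + 'static,
        R: Encode + 'static,
        O: Future<Output = R> + Send + 'static,
        F: Fn(M) -> O + Clone + Send + Sync + 'static,
    {
        let handler = move |request: Vec<u8>| -> BoxFuture {
            let func = func.clone(); // each future owns a copy instead of an Arc<F>
            Box::pin(async move {
                match M::decode(&request) {
                    Some(msg) => func(msg).await.encode(),
                    None => b"bad request".to_vec(),
                }
            })
        };
        self.handlers.insert(topic.to_string(), Box::new(handler));
    }

    // Builds the owned future while the map is borrowed; in `run`, only this
    // future would be moved into `tokio::spawn`.
    fn dispatch(&self, topic: &str, payload: Vec<u8>) -> Option<BoxFuture> {
        self.handlers.get(topic).map(|handler| handler(payload))
    }
}

// Minimal executor so the sketch runs without tokio.
struct ThreadWaker(Thread);
impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) { self.0.unpark(); }
}
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

struct GetStatus { device_id: u16 }
impl Decode for GetStatus {
    fn decode(bytes: &[u8]) -> Option<Self> {
        let device_id = std::str::from_utf8(bytes).ok()?.trim().parse().ok()?;
        Some(GetStatus { device_id })
    }
}
struct Status { busy: bool }
impl Encode for Status {
    fn encode(&self) -> Vec<u8> { format!("busy={}", self.busy).into_bytes() }
}

async fn get_status(cmd: GetStatus) -> Status {
    Status { busy: cmd.device_id == 1 } // placeholder device logic
}

fn main() {
    let mut interface = Interface::default();
    interface.add_handler("device/get_status", get_status);
    let fut = interface.dispatch("device/get_status", b"1".to_vec()).unwrap();
    println!("{}", String::from_utf8(block_on(fut)).unwrap());
}
```

The future returned by `dispatch` is `Send + 'static`, so in `run` it could be handed to `tokio::spawn` together with the cloned client while the handler itself never leaves the map. The trade-off is that handler state must be `Clone`, so shared mutable state still ends up behind an `Arc`, just inside the caller's closure instead of the interface.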

Summary of the changes:

  • RequestHandler takes in Bytes and returns Bytes; it's an internal detail.
  • RequestHandler needs to be an Arc instead of a Box so I can clone it and move it into the spawned task.
  • add_handler is generic and takes a generic function returning a Future. The compiler generates a specific implementation for each handler type.
  • The encoding/decoding is defined in add_handler, but it's still carried out in the spawned task.

Of course there's a lot of error handling and thread management in the actual code, but this is the core functionality. I also love that the caller of add_handler decides how to move whatever processes the request into the interface: it can be something that implements Clone, an Arc, or a direct implementation that needs no external context.
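That flexibility comes from `add_handler` being generic over `F`: what the closure captures is the caller's choice, and all of it is erased behind the same bytes-in/bytes-out handler type. The sketch below shows this with synchronous handlers to keep it short and std-only; the `Decode`/`Encode` traits are hypothetical stand-ins for serde, and the topics and handlers are made up. It registers a plain function with no context, a closure holding an `Arc` to shared state, and a closure that owns a copied value.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

// Hypothetical stand-ins for serde's Deserialize / Serialize.
trait Decode: Sized { fn decode(bytes: &[u8]) -> Option<Self>; }
trait Encode { fn encode(&self) -> Vec<u8>; }

impl Decode for u16 {
    fn decode(bytes: &[u8]) -> Option<Self> {
        std::str::from_utf8(bytes).ok()?.trim().parse().ok()
    }
}
impl Encode for bool {
    fn encode(&self) -> Vec<u8> { self.to_string().into_bytes() }
}
impl Encode for u32 {
    fn encode(&self) -> Vec<u8> { self.to_string().into_bytes() }
}

// Every handler is erased to the same bytes-in/bytes-out type.
type Handler = Arc<dyn Fn(&[u8]) -> Vec<u8> + Send + Sync>;

#[derive(Default)]
struct Registry {
    handlers: HashMap<String, Handler>,
}

impl Registry {
    fn add_handler<F, M, R>(&mut self, topic: &str, func: F)
    where
        M: Decode + 'static,
        R: Encode + 'static,
        F: Fn(M) -> R + Send + Sync + 'static,
    {
        let erased: Handler = Arc::new(move |bytes: &[u8]| match M::decode(bytes) {
            Some(msg) => func(msg).encode(),
            None => b"bad request".to_vec(),
        });
        self.handlers.insert(topic.to_string(), erased);
    }

    fn call(&self, topic: &str, payload: &[u8]) -> Option<Vec<u8>> {
        self.handlers.get(topic).map(|handler| handler(payload))
    }
}

// No external context at all.
fn is_busy(device_id: u16) -> bool {
    device_id == 1
}

fn build() -> (Registry, Arc<AtomicU32>) {
    let mut registry = Registry::default();
    registry.add_handler("device/get_status", is_busy);

    // Shared state: the caller chooses to move an Arc into the closure.
    let pings = Arc::new(AtomicU32::new(0));
    let counter = Arc::clone(&pings);
    registry.add_handler("device/ping", move |_id: u16| counter.fetch_add(1, Ordering::SeqCst) + 1);

    // Owned state: a copied value moved into the closure.
    let max_id: u16 = 8;
    registry.add_handler("device/valid", move |id: u16| id < max_id);

    (registry, pings)
}

fn main() {
    let (registry, pings) = build();
    for topic in ["device/get_status", "device/ping", "device/valid"] {
        let reply = registry.call(topic, b"3").unwrap();
        println!("{topic}: {}", String::from_utf8(reply).unwrap());
    }
    println!("pings seen by the caller: {}", pings.load(Ordering::SeqCst));
}
```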

There's still one thing I can't figure out. I tried to write add_handler slightly differently: I bound the closure to a variable and passed that to Arc::new() to make it more readable, but I'm getting an error that I just don't understand.

expected `{closure@main.rs:177:25}` to return `Pin<Box<dyn Future<Output = Bytes> + Send>>`, but it returns `Pin<Box<{async block@src/main.rs:179:22: 179:32}>>`
expected struct `std::pin::Pin<std::boxed::Box<(dyn std::future::Future<Output = bytes::Bytes> + std::marker::Send + 'static)>>`
   found struct `std::pin::Pin<std::boxed::Box<{async block@src/main.rs:179:22: 179:32}>>`
required for the cast from `std::sync::Arc<{closure@src/main.rs:177:25: 177:46}>` to `Arc<dyn Fn(Bytes) -> Pin<Box<...>> + Send + Sync>`
the full name for the type has been written to '/Users/tamas/Projects/zeus-controller/target/debug/deps/zeus_controller-b8baa3bac15e702f.long-type-11016490378759920010.txt'
consider using `--verbose` to print the full type name to the console

Couldn't find anything when I looked into it. How is it possible that a simple intermediate variable assignment breaks this code?

This has happened to me before, so I'm not surprised. I don't know the exact reason; I only know that closure type inference is very sensitive to the context in which the closure is created, so declaring a closure before using it can change how its type is inferred. And it is currently not always possible to annotate the variable's type correctly for closures.

If you search for "closure funnel" on this forum, you'll see that a workaround is to write a function that takes a closure parameter bounded by the desired type and returns that parameter unchanged. Creating the closure inline in the call to the funnel makes type inference work properly.

For example:
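A minimal sketch of a funnel for the `add_handler` case, assuming `Vec<u8>` in place of `Bytes` and a hand-rolled `block_on` in place of tokio so it runs with std only (`make_handler` and its echo-with-prefix behaviour are made up for illustration). Because the closure is created inside the call to `funnel`, its return type is taken from the `Fn` bound, and `Box::pin(async move { .. })` is coerced to the boxed trait object.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

type BoxFuture = Pin<Box<dyn Future<Output = Vec<u8>> + Send>>;
type RequestHandler = Arc<dyn Fn(Vec<u8>) -> BoxFuture + Send + Sync>;

// The funnel: an identity function whose `Fn` bound drives closure inference.
fn funnel<F>(f: F) -> F
where
    F: Fn(Vec<u8>) -> BoxFuture + Send + Sync + 'static,
{
    f
}

// Hypothetical handler: echoes the request after a shared prefix.
fn make_handler(prefix: Arc<Vec<u8>>) -> RequestHandler {
    // Still bound to a variable, but created inside `funnel`, so the return
    // type comes from the bound and `Box::pin(async ..)` is coerced to it.
    let handler = funnel(move |request: Vec<u8>| {
        let prefix = Arc::clone(&prefix);
        Box::pin(async move {
            let mut out = prefix.to_vec();
            out.extend_from_slice(&request);
            out
        })
    });
    Arc::new(handler)
}

// Minimal executor so the sketch runs without tokio.
struct ThreadWaker(Thread);
impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) { self.0.unpark(); }
}
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

fn main() {
    let handler = make_handler(Arc::new(b"echo:".to_vec()));
    let reply = block_on(handler(b"hi".to_vec()));
    println!("{}", String::from_utf8(reply).unwrap());
}
```

Because `Pin<Box<dyn Future ...>>` can be named here, annotating the closure's return type directly (`move |request: Vec<u8>| -> BoxFuture { .. }`) should also work in this particular case.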

Thanks. I've seen that trick, but I didn't know it's called funneling. I think this would be a good thing to point out in The Rust Programming Language book, or maybe as a hint in the compiler message. It took me quite a bit of experimentation to figure out that it's about where the closure is defined. Otherwise, what do you think of my proposed solution? Any improvements I could make?

There are effectively two closure inference modes:[1] standalone and bound-driven. When you assign the closure to a variable, that's standalone. In this mode,

  • input arguments annotated with a lifetime-taking type constructor, e.g. &_ or Node<'_>, will result in the compiler attempting a compilation that accepts any lifetime
  • the lack of such an annotation will result in the compiler attempting a compilation that only takes one lifetime (probably: that takes one specific type)
  • the compiler simply won't attempt a return type that can vary by input lifetime
  • the captures and traits implemented are based on the closure body[2]
So for example
let ex1 = |_| {};
let ex2 = |_: &_| {};
let ex3 = |x| x;
let ex4 = |x: &_| x;
let ex5 = |x: &str| -> &str { x };

// Feed `str` inference
let local = "hi".to_owned();
ex1(&*local);
ex2(&*local);
ex3(&*local);
ex4(&*local);
ex5(&*local);
  • ex1 takes a single lifetime
  • ex2 takes any lifetime
  • ex3 takes and returns a single lifetime
  • ex4 and ex5 simply don't compile, because the implementation the compiler attempts takes any lifetime but tries to return a single lifetime (the same lifetime for all input lifetimes)

Bound-driven is when the closure is defined in some place expecting a certain trait bound, typically an argument to some function.[3] In that case, the compiler will do its best to implement the expected bound, overriding the standalone inference. That's where the "funnel" trick comes in: it's an identity function that overrides closure inference.
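A small sketch of the funnel applied to the `ex4` example above, plus the `-> impl Fn(..)` position mentioned in footnote 3; `funnel` and `make_identity` are illustrative names. In both cases the `Fn(&str) -> &str` bound (whose elided lifetimes make it higher-ranked) tells the compiler that the output lifetime follows the input lifetime.

```rust
// Identity "funnel": its `Fn` bound drives closure inference.
fn funnel<F>(f: F) -> F
where
    F: Fn(&str) -> &str,
{
    f
}

// `-> impl Fn(..)` is another bound-driven position.
fn make_identity() -> impl Fn(&str) -> &str {
    |x| x
}

fn main() {
    // Standalone `let ex4 = |x: &_| x;` is rejected; through the funnel the
    // closure gets the signature `for<'a> fn(&'a str) -> &'a str`.
    let ex4 = funnel(|x| x);
    let ident = make_identity();

    let local = "hi".to_owned();
    println!("{} {}", ex4(local.as_str()), ident(local.as_str()));
}
```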

The bound has to be an Fn* trait specifically,[4] not a subtrait. That limits the applicability of the funnel trick when the return type cannot be named (closures, futures), which can be quite frustrating.
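One workaround when the return type is a future is to box it: `Pin<Box<dyn Future<Output = usize> + Send + 'a>>` is nameable, so it can go in the funnel's `Fn` bound, including a lifetime tied to the input. This is a sketch under the assumption that one allocation per call is acceptable; `count_words` and the `block_on` helper are made up for illustration.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Boxing gives the future a nameable type, lifetime included.
type BoxFuture<'a> = Pin<Box<dyn Future<Output = usize> + Send + 'a>>;

// Funnel whose bound ties the future's lifetime to the input's.
fn funnel<F>(f: F) -> F
where
    F: for<'a> Fn(&'a str) -> BoxFuture<'a>,
{
    f
}

// Minimal executor so the sketch runs without an async runtime.
struct ThreadWaker(Thread);
impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) { self.0.unpark(); }
}
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

fn main() {
    // The async block borrows `s`, which the higher-ranked bound permits.
    let count_words = funnel(|s| Box::pin(async move { s.split_whitespace().count() }));
    let text = String::from("relay one two");
    println!("{}", block_on(count_words(text.as_str())));
}
```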

Playground with more complete examples.


  1. I don't know how this is actually implemented; this is my perspective from writing code

  2. n.b. as we're about to see, "The mode is not affected by the code surrounding the closure" is incorrect

  3. but also things like -> impl Fn(..)

  4. or AsyncFn*, probably; they're new and I didn't test them
