Handler Registry for async functions

Hi Rustaceans,

I'm really struggling with Rust. I would like to create a handler registry to store async functions which modify some input struct. In JavaScript, that would be something like:

/**
 * Ordered collection of async handlers that are applied to a shared input.
 */
class Registry {
  constructor() {
    // Handlers in registration order.
    this.handlers = []
  }

  /**
   * Add a handler; it runs after every previously registered one.
   * @param {(input: any) => Promise<void>} fn
   */
  register(fn) {
    this.handlers.push(fn)
  }

  /**
   * Run every handler on `input` one after another, awaiting each before the
   * next starts. A rejected handler stops the loop and rejects the returned promise.
   */
  async run(input) {
    // Must be `this.handlers` (the array filled by `register`).
    for (const handler of this.handlers) await handler(input)
  }
}
const registry = new Registry()
// Sketch only: `...` stands for the handler's async body, so this line is not runnable as-is.
registry.register(async (input) => { ... })
await registry.run(someInput)

Now in Rust, that seems pretty hard. I tried many things and came up with the code below, which fails to compile with the following error:

124 |         registry.register(async_fn);
    |                  ^^^^^^^^ within `impl Future<Output = Result<(), String>>`, the trait `Unpin` is not implemented for `[async fn body@handlerx/src/lib.rs:120:69: 123:10]`

Could someone help me?

use async_trait::async_trait;
use std::marker::PhantomData;

/// One of four context variants; only `Create` and `Read` carry a payload.
///
/// NOTE(review): not referenced by the registry code in this listing.
pub enum Context<T> {
    /// Carries a [`CreateContext`] holding the `T` payload.
    Create(CreateContext<T>),
    /// Carries a (currently empty) [`ReadContext`].
    Read(ReadContext),
    /// No payload.
    Update,
    /// No payload.
    Delete,
}

/// Payload attached to [`Context::Create`].
pub struct CreateContext<T> {
    /// The wrapped value.
    pub data: T,
}

/// Payload attached to [`Context::Read`]; it has no fields yet.
pub struct ReadContext {}

/// An asynchronous step that may mutate a shared `Input` in place.
///
/// `#[async_trait]` is what allows this `async fn` to be called through a trait
/// object, which `HandlerRegistry` relies on (it stores `Box<dyn Handler<Input>>`).
/// Every implementor must also be `Send + Sync`.
#[async_trait]
pub trait Handler<Input>: Send + Sync {
    /// Processes `input`; an `Err(String)` signals failure.
    async fn handle(&self, input: &mut Input) -> Result<(), String>;
}

/// Ordered list of boxed handlers that are run against a shared `&mut Input`.
pub struct HandlerRegistry<Input> {
    // Handlers in registration order; `exec` walks them front to back.
    handlers: Vec<Box<dyn Handler<Input>>>,
}

impl<Input> HandlerRegistry<Input> {
    /// Creates an empty registry.
    pub fn new() -> Self {
        HandlerRegistry { handlers: vec![] }
    }

    /// Wraps `f` in an `AsyncFuncHandler` and appends it to the registry.
    ///
    /// NOTE(review): registering an `async fn` fails at this signature. Futures
    /// returned by `async fn` do not implement `Unpin`, so the `Fut: ... + Unpin`
    /// bound rejects them (rustc: "the trait `Unpin` is not implemented for
    /// `[async fn body@...]`"). Even without `Unpin`, `Fut: 'static` cannot hold
    /// for an `async fn(&mut Input)`, because its future keeps the `&mut Input`
    /// borrow, and a single `Fut` type cannot depend on that borrow's lifetime.
    pub fn register<F, Fut>(&mut self, f: F)
        where
        F: Fn(&mut Input) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<(), String>> + Send + 'static + std::marker::Unpin,
        Input: Send + Sync + 'static,
        {
            let handler = AsyncFuncHandler::new(Box::new(f));
            self.handlers.push(Box::new(handler));
        }


    /// Runs each handler on `input` in registration order, awaiting one before
    /// starting the next. Returns the first `Err`; later handlers are skipped.
    pub async fn exec(&self, input: &mut Input) -> Result<(), String> {
        for handler in &self.handlers {
            handler.handle(input).await?;
        }
        Ok(())
    }
}

/// Adapter that stores a boxed function `&mut Input -> Fut` and implements
/// [`Handler`] by calling it and awaiting the returned future.
///
/// NOTE(review): `Fut` is one fixed type, so it cannot vary with the lifetime of the
/// `&mut Input` argument. `Fut: 'static` also forbids the future from holding that
/// borrow at all. An `async fn(&mut Input)` future does hold it, so such functions
/// cannot fit here even after the unnecessary `Unpin` bound is removed (`.await`
/// pins the future itself).
struct AsyncFuncHandler<Input, Fut>
where
    Input: Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<(), String>> + Send + 'static + std::marker::Unpin,
{
    // The user function, boxed by `register` to erase its concrete type.
    f: Box<dyn Fn(&mut Input) -> Fut + Send + Sync>,
}

impl<Input, Fut> AsyncFuncHandler<Input, Fut>
where
    Input: Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<(), String>> + Send + 'static + std::marker::Unpin,
{
    /// Wraps an already-boxed function.
    fn new(f: Box<dyn Fn(&mut Input) -> Fut + Send + Sync>) -> Self {
        AsyncFuncHandler { f }
    }
}

#[async_trait]
impl<Input, Fut> Handler<Input> for AsyncFuncHandler<Input, Fut>
where
    Input: Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<(), String>> + Send + 'static + std::marker::Unpin,
{
    /// Calls the stored function on `input` and awaits its future.
    async fn handle(&self, input: &mut Input) -> Result<(), String> {
        (self.f)(input).await
    }
}


/// Alternative adapter that stores the function `F` unboxed; `PhantomData<Input>`
/// records the `Input` type parameter, which no real field mentions.
///
/// NOTE(review): not used by `register` in this listing (which builds an
/// `AsyncFuncHandler`). It has the same `Fut: 'static` limitation for futures that
/// borrow `&mut Input`. Also, `Fut` appears only in the where clause and in no field;
/// rustc likely rejects this as an unused type parameter (E0392) unless it is
/// recorded as well, e.g. `PhantomData<Fut>`. Confirm.
struct FuncHandler<Input, F, Fut>(F, PhantomData<Input>)
where
    F: Fn(&mut Input) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<(), String>> + Send + 'static;

impl<Input, F, Fut> FuncHandler<Input, F, Fut>
where
    F: Fn(&mut Input) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<(), String>> + Send + 'static,
{
    /// Wraps `f` without boxing it.
    fn new(f: F) -> Self {
        FuncHandler(f, PhantomData)
    }
}

#[async_trait]
impl<Input, F, Fut> Handler<Input> for FuncHandler<Input, F, Fut>
where
    F: Fn(&mut Input) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<(), String>> + Send + 'static,
    Input: Send + Sync + 'static,
{
    /// Calls the wrapped function on `input` and awaits its future.
    async fn handle(&self, input: &mut Input) -> Result<(), String> {
        (self.0)(input).await
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Intended check: a registered async fn is applied to the input by `exec`.
    #[tokio::test]
    async fn can_build_registry() {
        let mut registry = HandlerRegistry::new();

        #[derive(Debug)]
        struct FooCtx {
            a: u32,
            b: String,
        }

        // How to register an async function? This `register` call is where the
        // `Unpin` error is reported.
        async fn async_fn(input: &mut FooCtx) -> Result<(), String> {
            // Placeholder: the real async body goes here; as written, `input` is left unchanged.
            Ok(())
        }
        registry.register(async_fn);

        let mut input = FooCtx {
            a: 111,
            b: "bla".into(),
        };
        registry.exec(&mut input).await.unwrap();

        // NOTE(review): these expect `async_fn` to add 1 to `a` and append "boooo" to
        // `b`; with the placeholder body above they would fail even once it compiles.
        assert_eq!(input.a, 112);
        assert_eq!(input.b, "blaboooo");
    }
}

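Futures returned from async functions do not implement `Unpin` (they may hold references into themselves across `.await` points), so they do not satisfy the `Fut: std::marker::Unpin` bound on `HandlerRegistry::register()`. The fix is to remove all the `+ std::marker::Unpin` bounds; `.await` pins the future itself, so the bound is not needed.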

But after that, you'll have another issue: the async function's future is not `'static`, since it needs to store the temporary `&'_ mut Input` reference that it receives as input. Unfortunately, the bounds for async functions that take arbitrary-lifetime inputs can't be written directly: the future's type differs for each input lifetime `'a`, but a single `Fut` type parameter cannot mention `'a`. Instead, you must either require the user function to return a `Pin<Box<dyn Future<...> + Send>>` directly, or add a new helper trait to express the correct bounds (Rust Playground):

use async_trait::async_trait;
use std::{future::Future, pin::Pin};

/// One of four context variants; only `Create` and `Read` carry a payload.
///
/// NOTE(review): not referenced by the registry code in this listing.
pub enum Context<T> {
    /// Carries a [`CreateContext`] holding the `T` payload.
    Create(CreateContext<T>),
    /// Carries a (currently empty) [`ReadContext`].
    Read(ReadContext),
    /// No payload.
    Update,
    /// No payload.
    Delete,
}

/// Payload attached to [`Context::Create`].
pub struct CreateContext<T> {
    /// The wrapped value.
    pub data: T,
}

/// Payload attached to [`Context::Read`]; it has no fields yet.
pub struct ReadContext {}

/// An asynchronous step that may mutate a shared `Input` in place.
///
/// `#[async_trait]` is what allows this `async fn` to be called through a trait
/// object, which `HandlerRegistry` relies on (it stores `Box<dyn Handler<Input>>`).
/// Every implementor must also be `Send + Sync`.
#[async_trait]
pub trait Handler<Input>: Send + Sync {
    /// Processes `input`; an `Err(String)` signals failure.
    async fn handle(&self, input: &mut Input) -> Result<(), String>;
}

/// Ordered collection of async handlers that are run against a shared `&mut Input`.
pub struct HandlerRegistry<Input> {
    /// Handlers in registration order; `exec` runs them front to back.
    handlers: Vec<Box<dyn Handler<Input>>>,
}

impl<Input> HandlerRegistry<Input> {
    /// Creates an empty registry.
    pub fn new() -> Self {
        HandlerRegistry { handlers: vec![] }
    }

    /// Registers an async function such as `async fn(&mut Input) -> Result<(), String>`.
    ///
    /// The `for<'a> AsyncFunc<'a, Input>` bound accepts functions whose returned
    /// future borrows the `&'a mut Input` argument, for every lifetime `'a`. The
    /// future is boxed so that all registered handlers share one type-erased signature.
    /// (Closures of the form `|input| async { ... }` do not satisfy this bound.)
    pub fn register<F>(&mut self, f: F)
    where
        F: for<'a> AsyncFunc<'a, Input> + 'static,
        Input: Send + 'static,
    {
        let handler = AsyncFuncHandler::new(Box::new(move |input| Box::pin(f(input))));
        self.handlers.push(Box::new(handler));
    }

    /// Runs every handler on `input` sequentially, in registration order, awaiting
    /// each one before starting the next.
    ///
    /// # Errors
    ///
    /// Returns the first `Err` produced by a handler; handlers after it are not run.
    pub async fn exec(&self, input: &mut Input) -> Result<(), String> {
        for handler in &self.handlers {
            handler.handle(input).await?;
        }
        Ok(())
    }
}

/// Equivalent to [`HandlerRegistry::new`]: an empty registry. Provided so the type
/// works with `Default`-based construction (and satisfies `clippy::new_without_default`).
impl<Input> Default for HandlerRegistry<Input> {
    fn default() -> Self {
        Self::new()
    }
}

/// Helper trait that names the future type an async function returns for one
/// particular input lifetime `'a`.
///
/// A plain `F: Fn(&mut Input) -> Fut` bound cannot express that the future's type
/// depends on the (higher-ranked) lifetime of the argument. This trait makes the
/// lifetime a trait parameter and exposes the future as the associated type
/// `Self::Fut`. Callers can then write `F: for<'a> AsyncFunc<'a, Input>`, meaning
/// "for every `'a`, `F` maps `&'a mut Input` to some `Send` future yielding
/// `Result<(), String>`".
pub trait AsyncFunc<'a, Input>: Fn(&'a mut Input) -> Self::Fut + Send + Sync
where
    // `&'a mut Input` is only well-formed when `Input: 'a`; `'static` makes that
    // hold for every `'a` quantified by `for<'a>`.
    Input: 'static,
{
    /// The future returned when called with a `&'a mut Input`.
    type Fut: Future<Output = Result<(), String>> + Send;
}

/// Blanket impl: every `Send + Sync` function from `&'a mut Input` to a `Send`
/// future yielding `Result<(), String>` implements `AsyncFunc<'a, Input>`
/// automatically, so users never implement this trait by hand.
impl<'a, F, Input, Fut> AsyncFunc<'a, Input> for F
where
    F: Fn(&'a mut Input) -> Fut + Send + Sync,
    Input: 'static,
    Fut: Future<Output = Result<(), String>> + Send,
{
    type Fut = Fut;
}

/// Heap-allocated, pinned, `Send` future that may borrow data for `'a`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;

/// Type-erased handler function: for every borrow `'r` of the input, it returns a
/// boxed future that is allowed to hold on to that borrow.
type BoxAsyncFunc<Input> =
    Box<dyn for<'r> Fn(&'r mut Input) -> BoxFuture<'r, Result<(), String>> + Send + Sync>;

/// Adapter that holds a [`BoxAsyncFunc`] so it can be used as a handler.
struct AsyncFuncHandler<Input>(BoxAsyncFunc<Input>);

impl<Input> AsyncFuncHandler<Input> {
    /// Wraps an already type-erased function.
    fn new(func: BoxAsyncFunc<Input>) -> Self {
        Self(func)
    }
}

/// Lets a boxed async function be stored and run as a [`Handler`].
#[async_trait]
impl<Input: Send> Handler<Input> for AsyncFuncHandler<Input> {
    /// Invokes the stored function on `input`, then awaits the future it returned.
    async fn handle(&self, input: &mut Input) -> Result<(), String> {
        let pending = (self.0)(input);
        pending.await
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Registers one `async fn` and checks that `exec` applies it to the input.
    #[tokio::test]
    async fn can_build_registry() {
        #[derive(Debug)]
        struct FooCtx {
            a: u32,
            b: String,
        }

        async fn async_fn(ctx: &mut FooCtx) -> Result<(), String> {
            ctx.a += 1;
            ctx.b.push_str("boooo");
            Ok(())
        }

        let mut registry = HandlerRegistry::new();
        registry.register(async_fn);

        let mut ctx = FooCtx {
            a: 111,
            b: String::from("bla"),
        };
        registry.exec(&mut ctx).await.unwrap();

        assert_eq!(ctx.a, 112);
        assert_eq!(ctx.b, "blaboooo");
    }
}

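Note that this will not work with `|input| async { ... }` closures. The returned async block borrows `input`, and the compiler cannot currently infer a closure whose return type depends on the lifetime of its argument, so such closures fail the `for<'a> AsyncFunc<'a, Input>` bound. Register a named `async fn` instead, as the test does.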

2 Likes

Thanks a lot @LegionMammal978 , that helps a lot!

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.