Pass mutable async function as argument

Greetings!

Is there any way to write something like this?

use std::fmt::Display;
use std::future::Future;
use std::time::UNIX_EPOCH;

struct Client {
    error_counter: usize,
}

impl Client {
    fn new() -> Self {
        Self { error_counter: 0 }
    }

    async fn try_open(&mut self) -> Result<(), usize> {
        let is_err = std::time::SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            % 2;

        if is_err == 0 {
            Ok(())
        } else {
            self.error_counter += 1;
            Err(self.error_counter)
        }
    }
}

async fn try_do<F, Fut, T, E>(mut fun: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: Display,
{
    loop {
        match fun().await {
            Ok(ret) => break Ok(ret),
            Err(err) => {
                eprintln!("Got error: `{err}`.");
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let mut client = Client::new();

    // error: captured variable cannot escape `FnMut` closure body
    let _ret = try_do(|| async { client.try_open().await }).await;
}

The only way I can see to get rid of the copy-paste is to convert try_do into a macro.
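For comparison, a macro is not the only way out. The closure in the question is rejected because an FnMut closure cannot return a future that keeps borrowing the closure's own captured &mut client. A minimal sketch of one workaround, assuming the caller can hand the client to try_do explicitly: the closure receives &mut C on every call and returns a boxed future whose lifetime is tied to that argument. The Client below is a hypothetical deterministic stand-in (it fails twice, then succeeds), and block_on is a toy executor used only so the sketch runs without tokio; in the real program #[tokio::main] stays.

```rust
use std::fmt::Display;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Hypothetical stand-in for the question's Client: fails twice, then succeeds.
struct Client {
    error_counter: usize,
}

impl Client {
    async fn try_open(&mut self) -> Result<(), usize> {
        if self.error_counter < 2 {
            self.error_counter += 1;
            Err(self.error_counter)
        } else {
            Ok(())
        }
    }
}

// The closure gets `&'a mut C` on each call, and the higher-ranked bound lets the
// returned future borrow it for exactly that call, so no borrow escapes the closure.
async fn try_do<C, F, T, E>(ctx: &mut C, mut fun: F) -> Result<T, E>
where
    F: for<'a> FnMut(&'a mut C) -> Pin<Box<dyn Future<Output = Result<T, E>> + 'a>>,
    E: Display,
{
    loop {
        // Explicit reborrow: each iteration takes a fresh, short-lived &mut borrow.
        match fun(&mut *ctx).await {
            Ok(ret) => break Ok(ret),
            Err(err) => eprintln!("Got error: `{err}`."),
        }
    }
}

// Toy executor for the demo only: parks the thread until the waker unparks it.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = std::pin::pin!(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

fn main() {
    let mut client = Client { error_counter: 0 };
    let ret = block_on(try_do(&mut client, |c| Box::pin(c.try_open())));
    println!("{ret:?}, errors seen: {}", client.error_counter);
}
```

The trade-off is one heap allocation per attempt for the Box; the macro avoids that, but a function like this can be stored, passed around, and documented like any other.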

In a concurrent context, it's generally a pain to work with &mut references properly. Try switching to shared references and interior mutability:

use std::cell::Cell;
use std::fmt::Display;
use std::future::Future;
use std::time::UNIX_EPOCH;

struct Client {
    error_counter: Cell<usize>,
}

impl Client {
    fn new() -> Self {
        Self { error_counter: Cell::new(0) }
    }

    async fn try_open(&self) -> Result<(), usize> {
        let is_err = std::time::SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            % 2;

        if is_err == 0 {
            Ok(())
        } else {
            let new = self.error_counter.get() + 1;
            self.error_counter.set(new);
            Err(new)
        }
    }
}

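// `Fn` suffices here: the closure only captures `&client`, which every returned
// future can copy, so no future borrows from the closure itself.
async fn try_do<F, Fut, T, E>(fun: F) -> Result<T, E>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, E>>,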
    E: Display,
{
    loop {
        match fun().await {
            Ok(ret) => break Ok(ret),
            Err(err) => {
                eprintln!("Got error: `{err}`.");
            }
        }
    }
}

error_counter: usize is just an example. My real Client has methods that take &mut self.

Then replace Cell with atomics or RefCell or Mutex or RwLock.

It looks like this:

let mut client_guard = client.lock_owned().await;
<...>
tokio::task::spawn(async move {
    if try_do(|| client.try_open()).await.is_ok() {
        *client_guard = Some(client);
    }
});

To use a Mutex, I would have to put Client into a Mutex<Option<Client>>, with unnecessary locking and unwrapping. If I inline the function, it works as expected. I could use a macro instead of a function, but macros, IMHO, are not intended for this kind of thing.

You can always wrap the more complex mutability in an interior Mutex. I've been doing this for all my tokio work. I don't like it, but it gives me the cleanest-looking fn main, and I keep refactoring whenever I find an alternative.

In tokio you'll have to decide whether to use tokio's Mutex or std's Mutex. You should NOT hold a std Mutex guard across .await points; if you have a two-step I/O operation that must keep the lock across an .await, tokio's Mutex is the better choice. Otherwise, make sure you release the mutex before ANY I/O (and definitely before any .await) and re-acquire it afterwards. Personally, I've been putting everything in dedicated scoped threads, as I find that faster and easier to reason about.
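The "release before the .await, re-acquire after" advice can be sketched with the standard library alone. Assumptions: fake_io is a hypothetical placeholder for a real async I/O call, and block_on is a toy busy-polling executor standing in for tokio. Dropping the guard before the .await also keeps the future Send (std's MutexGuard is not Send), which is what tokio::spawn requires; assert_send checks that at compile time.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};
use std::thread;

// Hypothetical stand-in for an async I/O call (e.g. a network round trip).
async fn fake_io(request: usize) -> usize {
    request * 10
}

// Hold the std Mutex only for the synchronous parts; never across an `.await`.
async fn step(log: &Mutex<Vec<usize>>) -> usize {
    let request = {
        let guard = log.lock().unwrap();
        guard.len() + 1
    }; // guard dropped here, before the `.await` below
    let reply = fake_io(request).await;
    log.lock().unwrap().push(reply); // short critical section after the I/O
    reply
}

// Compile-time check: the future is Send, so it could be given to tokio::spawn.
fn assert_send<T: Send>(_: &T) {}

struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Toy executor for the demo only: busy-polls with a no-op waker.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
        thread::yield_now();
    }
}

fn main() {
    let log = Mutex::new(Vec::new());
    let fut = step(&log);
    assert_send(&fut);
    println!("{}", block_on(fut));
    println!("{}", block_on(step(&log)));
    println!("{:?}", log.lock().unwrap());
}
```

If the guard were instead bound outside the inner block and kept alive across fake_io(...).await, the assert_send call would stop compiling, which is the same error tokio::spawn reports.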

Then a macro seems the more reasonable choice =)

Why do you think the locking is unnecessary? The runtime is free to move your async code to arbitrary threads, so if you want shared mutability, you need to synchronize it.

Because I want a function that behaves like this:

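macro_rules! try_do {
    // $fun is re-evaluated on every iteration, so each attempt takes a fresh
    // &mut borrow of the client instead of one that must outlive a closure call.
    ($fun:expr) => {{
        loop {
            match $fun.await {
                Ok(ret) => break Ok(ret),
                // Stop retrying on this specific error code.
                Err(2) => break Err(2),
                Err(err) => {
                    eprintln!("Got error: `{err}`.");
                }
            }
        }
    }};
}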

#[tokio::main]
async fn main() {
    let mut client = Client::new();

    let _ret = try_do!(client.try_open());
}
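On Rust 1.85 or later, async closures make a plain function possible without passing the client separately: an async closure that borrows client mutably implements AsyncFnMut, and the futures it returns are allowed to borrow from the closure itself. A sketch mirroring the macro above, assuming that toolchain. give_up is a hypothetical parameter replacing the hard-coded Err(2) arm, Client is a stand-in that always fails, and block_on is a toy executor instead of tokio.

```rust
use std::fmt::Display;
use std::future::Future;
use std::ops::AsyncFnMut; // stable since Rust 1.85
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread;

// Hypothetical stand-in for the question's Client: every attempt fails.
struct Client {
    error_counter: usize,
}

impl Client {
    async fn try_open(&mut self) -> Result<(), usize> {
        self.error_counter += 1;
        Err(self.error_counter)
    }
}

// Same loop as the macro, as a function; `give_up` generalizes the `Err(2)` arm.
async fn try_do<F, T, E>(mut fun: F, give_up: impl Fn(&E) -> bool) -> Result<T, E>
where
    F: AsyncFnMut() -> Result<T, E>,
    E: Display,
{
    loop {
        match fun().await {
            Ok(ret) => break Ok(ret),
            Err(err) if give_up(&err) => break Err(err),
            Err(err) => eprintln!("Got error: `{err}`."),
        }
    }
}

struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Toy executor for the demo only: busy-polls with a no-op waker.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
        thread::yield_now();
    }
}

fn main() {
    let mut client = Client { error_counter: 0 };
    // The async closure mutably borrows `client`; each returned future reuses that borrow.
    let ret = block_on(try_do(async || client.try_open().await, |e: &usize| *e == 2));
    println!("{ret:?} after {} attempts", client.error_counter);
}
```

With #[tokio::main] the call site would read try_do(async || client.try_open().await, |e: &usize| *e == 2).await, which is close to the macro version without giving up type checking of the function signature.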