Calling an argument, which is a closure, with a closure as argument

I have a function foo that takes a closure as an argument. I want to call this closure within foo and pass another closure to it as an argument. The following code works:

fn foo<C>(mut closure: C)
where
    C: FnMut(Box<dyn FnMut()>),
{
    closure(Box::new(|| println!("Hello!")));
}

fn main() {
    foo(|mut callback| {
        println!("Calling callback.");
        callback();
        println!("Called callback.");
    });
}

(Playground)

But can this also be done without Box<dyn …>? If yes, how? If no, does this require a Box by principle or is this just something that's missing in Rust and could be added in future?

I tried using TAITs, but without success so far (Playground).

You don't need a box, you could take a &mut dyn FnMut() :wink:.
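A minimal sketch of that suggestion, applied to the original example: foo keeps the concrete closure on its own stack frame and lends it out as `&mut dyn FnMut()`, so nothing is boxed. Only the bound and the call site change.

```rust
// Same shape as the original `foo`, but the callback is borrowed as
// `&mut dyn FnMut()` instead of being boxed, so no heap allocation is needed.
fn foo<C>(mut closure: C)
where
    C: FnMut(&mut dyn FnMut()),
{
    // The concrete closure lives in foo's stack frame.
    let mut callback = || println!("Hello!");
    // `&mut callback` is coerced to `&mut dyn FnMut()` at the call.
    closure(&mut callback);
}

fn main() {
    foo(|callback| {
        println!("Calling callback.");
        callback();
        println!("Called callback.");
    });
}
```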

You want something like a for<A: FnMut()> FnMut(A), but

  • for doesn't support non-lifetimes
    • Although this may get relaxed post-GAT
  • for doesn't support inline bounds

And on top of that,

  • Rust doesn't support generic closures

So achieving your goal without dyn seems a long way off.
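One workaround that avoids dyn (a common pattern, not something proposed in this thread) is to express the missing "generic closure" as a trait with a generic method: a type implementing the trait can be called with any `A: FnMut()`, which is roughly what `for<A: FnMut()> FnMut(A)` would mean. The cost is that the caller implements the trait on a named type instead of writing a closure. The names `CallbackUser` and `Caller` are hypothetical.

```rust
// Stand-in for the unwritable bound `for<A: FnMut()> FnMut(A)`:
// anything implementing this can be handed any `A: FnMut()`.
trait CallbackUser {
    fn call_with<A: FnMut()>(&mut self, callback: A);
}

fn foo<C: CallbackUser>(mut user: C) -> u32 {
    let mut count = 0;
    // A plain, unboxed closure; `call_with` is monomorphized for its type.
    user.call_with(|| {
        count += 1;
        println!("Hello!");
    });
    count
}

// The caller's logic lives in a named type instead of a closure.
struct Caller;

impl CallbackUser for Caller {
    fn call_with<A: FnMut()>(&mut self, mut callback: A) {
        println!("Calling callback.");
        callback();
        println!("Called callback.");
    }
}

fn main() {
    let used = foo(Caller);
    println!("Callback used {} times.", used);
}
```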

I found the following workaround:

struct Callback {
    state: i32,
}

impl Callback {
    fn call(&mut self) {
        self.state += 1;
        println!("Hello!");
    }
}

fn foo<C>(mut closure: C)
where
    C: FnMut(&mut Callback),
{
    let mut callback = Callback { state: 0 };
    closure(&mut callback);
    println!("Callback used {} times.", callback.state);
}

fn main() {
    foo(|callback| {
        println!("Calling callback.");
        callback.call();
        println!("Called callback.");
    });
}

(Playground)

Output:

Calling callback.
Hello!
Called callback.
Callback used 1 times.

But then the callback isn't a closure (I have to store the state manually in a struct). Calling it also isn't as nice, as I have to write callback.call() instead of callback().

Oh, that's a new concept for me. I take it that using &mut dyn instead of Box<dyn> avoids a heap allocation, right?

I tried that too and yeah, the compiler told me that I can only use lifetimes in for<…>.

Okay, I didn't even get that far :sweat_smile:.

Yes -- you're probably familiar with how slices work, where [T] is dynamically sized, and &[T] (or &mut [T]) are wide pointers: one usize-sized field holds the pointer to the memory, and another usize-sized field holds the number of elements in the slice. The slice itself might be on the heap, but doesn't have to be -- like references generally, they can point anywhere.
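A quick way to see the wide pointers described above is to compare pointer sizes. This is a small sketch; `wide_pointer_words` is a hypothetical helper, not a std function.

```rust
use std::mem::size_of;

// How many usize-sized words a reference to `T` occupies (hypothetical helper).
fn wide_pointer_words<T: ?Sized>() -> usize {
    size_of::<&T>() / size_of::<usize>()
}

fn main() {
    // A reference to a sized type is a thin pointer: one word.
    assert_eq!(wide_pointer_words::<u8>(), 1);
    // `&[T]` and `&mut [T]` carry a data pointer plus a length: two words.
    assert_eq!(wide_pointer_words::<[u8]>(), 2);
    assert_eq!(size_of::<&mut [u64]>(), 2 * size_of::<usize>());

    // The slice can point at stack memory; nothing here allocates.
    let array = [1, 2, 3, 4];
    let slice: &[i32] = &array[1..];
    println!("{:?} has {} elements", slice, slice.len()); // [2, 3, 4] has 3 elements
}
```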

You can also have a Box<[T]> which works pretty much the same... only the slice is definitely on the heap and is owned by the Box. Perhaps a more common, but similar, pattern is an Arc<str>.
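For illustration, a `Box<[T]>` can be obtained from a `Vec<T>` with `into_boxed_slice`. The sketch below (with a hypothetical `to_boxed` wrapper) checks that the box is the same two-word wide pointer as `&[T]`, while owning its elements.

```rust
use std::mem::size_of;

// Turn a Vec into an owned, fixed-length slice on the heap (hypothetical wrapper).
fn to_boxed(v: Vec<u32>) -> Box<[u32]> {
    v.into_boxed_slice()
}

fn main() {
    let boxed = to_boxed(vec![10, 20, 30]);
    // Like `&[T]`, a `Box<[T]>` is a wide pointer (data pointer + length) ...
    assert_eq!(size_of::<Box<[u32]>>(), 2 * size_of::<usize>());
    // ... but it owns the elements, which are freed when `boxed` is dropped.
    let view: &[u32] = &boxed;
    println!("{:?}", view); // [10, 20, 30]
}
```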

dyn FnMut() (or any other dyn Trait) works similarly, even though it is dynamically sized for a different reason (the underlying type-erased types could have different sizes). Instead of the number of elements, the metadata field of the wide pointer is a pointer to a vtable (which also records the size of the erased type). Just like with slices, you can own the unsized data (Box<dyn Trait>) or you can just have a reference to it (&dyn Trait, &mut dyn Trait). In the second case, like with references generally, the object could be anywhere and doesn't have to be on the heap.
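A sketch of the vtable point above: `size_of_val` on a `&dyn Fn() -> usize` returns the size of the concrete closure, which can only come from the vtable since the type is erased. `erased_size` is a hypothetical helper.

```rust
use std::mem::{size_of, size_of_val};

// Size of the erased value, read through the wide pointer's vtable.
fn erased_size(f: &dyn Fn() -> usize) -> usize {
    size_of_val(f)
}

fn main() {
    // A `move` closure that owns a 32-byte array.
    let data = [1u8; 32];
    let f = move || data.iter().map(|&b| b as usize).sum::<usize>();

    // `&dyn Trait` is a wide pointer: data pointer + vtable pointer.
    assert_eq!(size_of::<&dyn Fn() -> usize>(), 2 * size_of::<usize>());
    // The vtable knows the concrete size even though the type is erased.
    assert_eq!(erased_size(&f), size_of_val(&f));
    println!("erased closure is {} bytes, returns {}", erased_size(&f), f());
}
```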

So if you have an F: FnMut(), you can coerce it to a dyn FnMut(), but you have to keep this unsized type behind some kind of pointer -- either a &mut dyn FnMut() without moving it, or a Box<dyn FnMut()> on the heap.
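Both coercions side by side, as a sketch: one function taking `&mut dyn FnMut()` accepts a closure borrowed in place as well as one moved into a `Box`. The name `call_twice` is hypothetical.

```rust
// Works with any `dyn FnMut()`, wherever the closure actually lives.
fn call_twice(f: &mut dyn FnMut()) {
    f();
    f();
}

fn main() {
    let mut count = 0;
    {
        // Borrowed: the closure stays where it is (on the stack); no allocation.
        let mut on_stack = || count += 1;
        call_twice(&mut on_stack);
    }
    // Owned: the closure is moved into a heap allocation.
    let mut boxed: Box<dyn FnMut()> = Box::new(|| println!("boxed call"));
    call_twice(&mut *boxed);
    println!("count = {}", count); // count = 2
}
```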

You can make the TAIT version work if you shuffle things around so that the C type parameter doesn't get involved (Playground).

I'm not sure whether the original way you wrote that version should work or not. The closure doesn't actually capture anything connected to C, but there may be a good reason to act like it does there.

I tried to see if I can use this to modify variables local to foo. It seems to work too, but it requires capturing lifetimes (Playground). The syntax is a bit unwieldy.
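For comparison, the &mut dyn approach can also modify variables local to foo, and there the lifetimes are normally hidden by elision. This sketch spells out what the short bound `FnMut(&mut dyn FnMut())` should expand to under the default trait-object lifetime rules (the object only has to outlive the borrow `'a`); writing it explicitly is optional and shown only for illustration.

```rust
// Explicit form of `C: FnMut(&mut dyn FnMut())`: for every borrow `'a`,
// the callback object only has to live for `'a`, so it may borrow foo's locals.
fn foo<C>(mut closure: C) -> u32
where
    C: for<'a> FnMut(&'a mut (dyn FnMut() + 'a)),
{
    let mut count = 0;
    // Mutably borrows `count`; the borrow ends after the last use of `callback`.
    let mut callback = || count += 1;
    closure(&mut callback);
    count
}

fn main() {
    let n = foo(|callback| {
        callback();
        callback();
    });
    println!("Callback used {} times.", n); // Callback used 2 times.
}
```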

Thanks @quinedot for the detailed explanation. I might go for the &mut dyn approach (but still experimenting a bit).

Is that similar to how boxed slices work, except that it's a str here and an Arc instead of a Box? I take it that when the last Arc is dropped, the memory is freed? I haven't figured out yet how to create an Arc<str> myself.

Yes to all of that. There are From<String> and From<&str> implementations.
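A short sketch using those From impls; `shared` is a hypothetical wrapper around `Arc::from`.

```rust
use std::sync::Arc;

// Copy a &str into a new reference-counted allocation (hypothetical wrapper).
fn shared(s: &str) -> Arc<str> {
    Arc::from(s)
}

fn main() {
    let a: Arc<str> = shared("hello");
    // From a String, via the From<String> impl.
    let b: Arc<str> = Arc::from(String::from("world"));

    // Cloning only increments the reference count; the text isn't copied.
    let a2 = Arc::clone(&a);
    assert_eq!(Arc::strong_count(&a), 2);
    assert!(Arc::ptr_eq(&a, &a2));
    drop(a2);
    assert_eq!(Arc::strong_count(&a), 1);

    println!("{} {}", a, b); // hello world
    // `a` and `b` are the last owners, so both allocations are freed here.
}
```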

Did I mention I actually would like this to work with async too? :speak_no_evil:

Again, I found a solution with Box (and Pin)… (and Arc and Mutex :see_no_evil:)

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

async fn foo<C, Fut>(mut closure: C) -> i32
where
    C: FnMut(Box<dyn FnMut() -> Pin<Box<dyn Future<Output = ()>>>>) -> Fut,
    Fut: Future<Output = ()>,
{
    let counter: Arc<Mutex<i32>> = Arc::new(Mutex::new(0));
    let counter2 = counter.clone();
    closure(Box::new(move || {
        let counter3 = counter2.clone();
        Box::pin(async move {
            *counter3.lock().unwrap() += 1;
            println!("Hello!");
            tokio::task::yield_now().await;
        })
    })).await;
    let x = *counter.lock().unwrap();
    x
}

#[tokio::main]
async fn main() {
    let retval = foo(|mut callback| {
        async move {
            println!("Calling callback.");
            callback().await;
            println!("Called callback.");
        }
    }).await;
    println!("Callback used {} times.", retval)
}

(Playground)

Output:

Calling callback.
Hello!
Called callback.
Callback used 1 times.

Is there any way to make this mess cleaner?

Edit: Perhaps this is pushing things too far and macros are a better alternative for me.
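One possible partial cleanup, offered as a sketch rather than a definitive answer: the boxed futures here are not Send and everything runs on one task, so Arc<Mutex<i32>> can likely be replaced by Rc<Cell<i32>>, and a type alias (the name `Callback` is hypothetical) hides the long boxed type. To run without tokio, the sketch assumes a hypothetical minimal `block_on` that busy-polls with a no-op waker and omits the `yield_now()` call; that executor is only adequate for this demo, and with a real runtime you would keep #[tokio::main].

```rust
use std::cell::Cell;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical alias: each call returns a boxed (non-Send) future.
type Callback = Box<dyn FnMut() -> Pin<Box<dyn Future<Output = ()>>>>;

async fn foo<C, Fut>(mut closure: C) -> i32
where
    C: FnMut(Callback) -> Fut,
    Fut: Future<Output = ()>,
{
    // Nothing here has to be Send, so Rc<Cell<_>> replaces Arc<Mutex<_>>.
    let counter = Rc::new(Cell::new(0));
    let counter2 = Rc::clone(&counter);
    closure(Box::new(move || {
        let counter3 = Rc::clone(&counter2);
        Box::pin(async move {
            counter3.set(counter3.get() + 1);
            println!("Hello!");
        })
    }))
    .await;
    counter.get()
}

// Hypothetical minimal executor so the sketch runs without tokio.
struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    // Busy-polls; fine here because no future ever waits on an external event.
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

fn main() {
    let retval = block_on(foo(|mut callback| async move {
        println!("Calling callback.");
        callback().await;
        println!("Called callback.");
    }));
    println!("Callback used {} times.", retval);
}
```

This removes the Arc/Mutex layer but not the Box/Pin/dyn layers, which remain necessary for the same reasons discussed above for the synchronous version.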


My actual use case:
pub struct FreqShift {
    sender: Sender<Chunk<Complex<f32>>>,
    receiver: Receiver<Chunk<Complex<f32>>>,
    freq: watch::Sender<(isize, isize)>,
}

impl FreqShift {
    pub fn new(numerator: isize, denominator: isize) -> Self {
        let sender = Sender::new(16); // TODO
        let receiver = Receiver::<Chunk<Complex<f32>>>::new();
        let (tx, mut rx) = watch::channel((numerator, denominator));
        let mut input = receiver.stream();
        let output = sender.clone();
        spawn(async move {
            const TAU: f32 = std::f32::consts::TAU;
            assert!(denominator >= 0, "denominator must be non-negative");
            let calculate_phase = |(numerator, denominator): (isize, isize)| {
                let mut phase_vec: Vec<Complex<f32>> =
                    Vec::with_capacity(denominator.try_into().unwrap());
                let mut i: isize = 0;
                for _ in 0..denominator {
                    let (im, re) = <f32>::sin_cos(i as f32 / denominator as f32 * TAU);
                    phase_vec.push(Complex::new(re, im));
                    i += numerator;
                    i %= denominator;
                }
                let mut phase_iter = phase_vec.into_iter().cycle();
                move || phase_iter.next().unwrap()
            };
            let mut next_phase = calculate_phase((numerator, denominator));
            let mut buf_pool = ChunkBufPool::<Complex<f32>>::new();
            loop {
                match input.recv().await {
                    Ok(input_chunk) => {
                        if rx.has_changed().unwrap() {
                            let (numerator, denominator) = *rx.borrow_and_update();
                            next_phase = calculate_phase((numerator, denominator));
                        }
                        let mut output_chunk = buf_pool.get_with_capacity(input_chunk.len());
                        for sample in input_chunk.iter() {
                            output_chunk.push(sample * next_phase());
                        }
                        output.send(output_chunk.finalize());
                    }
                    Err(err) => {
                        output.forward_error(err);
                        if err == RecvError::Closed {
                            return;
                        }
                    }
                }
            }
        });
        FreqShift {
            sender,
            receiver,
            freq: tx,
        }
    }
    pub fn set_freq(&self, numerator: isize, denominator: isize) {
        self.freq.send_replace((numerator, denominator));
    }
}

impl Consumer<Chunk<Complex<f32>>> for FreqShift {
    fn receiver(&self) -> &Receiver<Chunk<Complex<f32>>> {
        &self.receiver
    }
}

impl Producer<Chunk<Complex<f32>>> for FreqShift {
    fn connector(&self) -> SenderConnector<Chunk<Complex<f32>>> {
        self.sender.connector()
    }
}

(The code is still under development and just a rough sketch.)

I would like to avoid the boilerplate of:

loop {
    match input.recv().await {
        Ok(input_chunk) => { /* … */ }
        Err(err) => {
            output.forward_error(err);
            /* … */
            if err == RecvError::Closed {
                return;
            }
        }
    }
}

I thought about writing a function that takes two async closures, but I end up in a mess of Box, dyn, Arc, move, async, Future, &mut, etc. That's why I tried to break this problem down into a simpler toy example.

Anyway, like I said, I feel like the complexity explodes, so maybe I should just live with the boilerplate and/or consider a macro (unless it's easier than I think).
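Since a macro came up as an option, here is a rough sketch of one under explicit assumptions: `recv_loop!`, `MockInput` and this `RecvError` enum (with `Lagged`/`Closed` variants) are hypothetical stand-ins, and the receive call is synchronous only so that the example runs on its own. In the real async task, the first macro argument would presumably be `input.recv().await`, which the macro re-evaluates on every iteration.

```rust
use std::collections::VecDeque;

// Hypothetical stand-in for the real error type. It must be Copy because it
// is used again after being forwarded, as in the hand-written loop.
#[derive(Clone, Copy, Debug, PartialEq)]
enum RecvError {
    Lagged,
    Closed,
}

// Wraps the receive / forward-error / return-on-Closed boilerplate.
macro_rules! recv_loop {
    ($recv:expr, |$item:ident| $on_item:block, |$err:ident| $on_err:block) => {
        loop {
            match $recv {
                Ok($item) => $on_item,
                Err($err) => {
                    $on_err
                    if $err == RecvError::Closed {
                        return;
                    }
                }
            }
        }
    };
}

// Hypothetical synchronous input, only to make the sketch runnable.
struct MockInput {
    queue: VecDeque<Result<i32, RecvError>>,
}

impl MockInput {
    fn recv(&mut self) -> Result<i32, RecvError> {
        // Report Closed once everything has been delivered.
        self.queue.pop_front().unwrap_or(Err(RecvError::Closed))
    }
}

fn process(input: &mut MockInput, out: &mut Vec<i32>, errors: &mut Vec<RecvError>) {
    recv_loop!(
        input.recv(),
        |chunk| { out.push(chunk * 2); },
        |err| { errors.push(err); }
    );
}

fn main() {
    let mut input = MockInput {
        queue: VecDeque::from(vec![Ok(1), Err(RecvError::Lagged), Ok(2)]),
    };
    let (mut out, mut errors) = (Vec::new(), Vec::new());
    process(&mut input, &mut out, &mut errors);
    println!("{:?} {:?}", out, errors); // [2, 4] [Lagged, Closed]
}
```

As in the hand-written loop, `return` inside the macro returns from the enclosing function (or from the enclosing async block when used inside spawn(async move { ... })).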