Dyn dispatch on multiple types

Hello,

The code below implements dyn dispatch based on the types of two arguments, using a HashMap instead of a vtable. Unfortunately, it appears to be about 25x slower than the vtable approach you get from regular dyn, and it requires cumbersome startup-time initialization.

So my question is: is there a clever way I can actually tap into the vtable mechanism to get this behavior with compile-time creation and without the call-time overhead?

Or are there other things I could be doing better, from a runtime efficiency perspective?

use core::any::{Any, TypeId};
use std::collections::{HashMap};
use std::time::{Duration, SystemTime};

struct Registry {
    map: HashMap<(TypeId, TypeId), fn(&i32, &i32)>
}

impl Registry {

    fn new() -> Self {
        Self {
            map: HashMap::new()
        }
    }

    fn register_func<A: 'static, B: 'static>(&mut self, f: fn(&A, &B)) {

        let type_id_a = TypeId::of::<A>();
        let type_id_b = TypeId::of::<B>();
        
        //Erase the concrete argument types from the function pointer's signature
        let cast_f = unsafe{ std::mem::transmute(f) };
        self.map.insert((type_id_a, type_id_b), cast_f);
    }

    fn call_func(&self, a: &dyn Any, b: &dyn Any) {

        //Retrieve the function pointer
        let type_id_a = a.type_id();
        let type_id_b = b.type_id();
        let generic_f = self.map.get(&(type_id_a, type_id_b)).unwrap();

        //Strip the vtables and reinterpret the data pointers as &i32, regardless of the actual types
        let cast_a = unsafe{ &*(a as *const dyn Any as *const i32) };
        let cast_b = unsafe{ &*(b as *const dyn Any as *const i32) };
        generic_f(cast_a, cast_b);
    }
}

fn pair_func_a(a: &i32, b: &i32) {
    println!("func_a({}, {})", a, b);
}

fn pair_func_b(a: &i32, b: &char) {
    println!("func_b({}, {})", a, b);
}

fn pair_func_c(a: &char, b: &i32) {
    println!("func_c({}, {})", a, b);
}

fn main() {

    let mut registry = Registry::new();

    registry.register_func(pair_func_a);
    registry.register_func(pair_func_b);
    registry.register_func(pair_func_c);

    registry.call_func(&5, &1);
    registry.call_func(&5, &'b');
    registry.call_func(&'c', &1);

    static_baseline();
    dyn_baseline();
    double_dyn_time();
}

//-----------------------------------------
//Benchmarks
//-----------------------------------------

fn static_func(a: &i32, b: &i32) {
    //Just to prevent the optimizer from eliminating the call
    if *a == 99999 {
        println!("single_static_a({}, {})", a, b);
    }
}

fn static_baseline() {

    let timer_start = SystemTime::now();

    for i in 0..1000000 {
        static_func(&i, &i);
    }

    println!("Static Baseline Time = {:?}", timer_start.elapsed().unwrap_or(Duration::new(0, 0))); 
}

trait SingleDyn : Any {
    fn dyn_func(&self, b: &i32);
}

impl SingleDyn for i32 {
    fn dyn_func(&self, b: &i32) {
        //Just to prevent the optimizer from eliminating the call
        if *self == 99999 {
            println!("single_dyn_a({}, {})", self, b);
        }
    }
}

fn single_dyn_dispatch(a: &dyn SingleDyn, b: &i32) {
    a.dyn_func(b); //I suspect the dyn dispatch might be being optimized away :-(
}

fn dyn_baseline() {

    let timer_start = SystemTime::now();

    for i in 0..1000000 {
        single_dyn_dispatch(&i, &i);
    }

    println!("Dyn Baseline Time = {:?}", timer_start.elapsed().unwrap_or(Duration::new(0, 0))); 
}

fn double_dyn_func(a: &i32, b: &i32) {
    //Just to prevent the optimizer from eliminating the call
    if *a == 99999 {
        println!("double_dyn_a({}, {})", a, b);
    }
}

fn double_dyn_time() {

    let mut registry = Registry::new();
    registry.register_func(double_dyn_func);
    
    let timer_start = SystemTime::now();

    for i in 0..1000000 {
        registry.call_func(&i, &i);
    }

    println!("Double Dyn Time = {:?}", timer_start.elapsed().unwrap_or(Duration::new(0, 0))); 
}

Thank you.

Of course there is! Do it one at a time using regular dyn Trait. Playground.
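
Roughly, the idea looks like this (a minimal sketch, not the exact playground code; the Display supertrait is only there so the sketch can print both values):

use std::fmt::Display;

// Dispatch on the first argument...
trait LevelOne: Display {
    fn level_one(&self, arg2: &dyn LevelTwo);
}

// ...then on the second.
trait LevelTwo: Display {
    fn level_two(&self, arg1: &dyn LevelOne);
}

impl LevelOne for i32 {
    fn level_one(&self, arg2: &dyn LevelTwo) {
        // the first argument's type is concrete (i32) from here on
        arg2.level_two(self)
    }
}

impl LevelOne for char {
    fn level_one(&self, arg2: &dyn LevelTwo) {
        arg2.level_two(self)
    }
}

impl LevelTwo for i32 {
    fn level_two(&self, arg1: &dyn LevelOne) {
        // the second argument's type is concrete (i32) here; the first is still behind dyn
        println!("second arg is an i32: ({}, {})", arg1, self)
    }
}

impl LevelTwo for char {
    fn level_two(&self, arg1: &dyn LevelOne) {
        println!("second arg is a char: ({}, {})", arg1, self)
    }
}

fn main() {
    let a: &dyn LevelOne = &5_i32;
    let b: &dyn LevelTwo = &'x';
    a.level_one(b); // prints "second arg is a char: (5, x)"
}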


P.S.: you shouldn't use unsafe for this. Unsurprisingly, your implementation is unsound (it has UB), because you are invoking functions through a pointer to the wrong type.

1 Like

In OO languages there's the visitor pattern to solve this problem, I guess you can adapt it for Rust too.

EDIT: On second thoughts, it probably does not work (at least not straightforwardly), because Rust does not support function overloading.

The visitor pattern works just fine in Rust; e.g. Serde deserialization works with visitors all the way down. There's no need for function overloading; generics (parametric polymorphism) are strictly more powerful than "overloading" as conventionally defined (ad-hoc polymorphism).

3 Likes

I'm not convinced that it works with dynamic dispatch, because AFAIK you cannot call generic methods on trait objects. Or am I mistaken?

The "traditional" visitor pattern would look something like this:

trait Visitee {
    fn accept(&self, visitor: &dyn Visitor) {
        visitor.visit(self)
    }
}

trait Visitor {
    // note: two methods with the same name don't compile in Rust;
    // this is OO-style pseudocode that relies on overload resolution
    fn visit(&self, visitee: &i32);
    fn visit(&self, visitee: &str);
    // .. more types
}

i.e. it's a combination of dynamic dispatch and overload resolution.

If you look at how serde implements the visitor pattern, you will note that those visit methods have different names. E.g. there would be a visit_str and a visit_i32 method. Overloading is not necessary here.

As for generic methods on trait objects, you can replace the generic method with a method that takes another trait object, like your accept method does.
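
Something like this, i.e. a sketch of the shape, not Serde's actual traits:

trait Visitor {
    // one distinctly named method per supported type; no overloading needed
    fn visit_i32(&self, value: &i32);
    fn visit_str(&self, value: &str);
}

trait Visitee {
    // takes a trait object instead of being generic over the visitor type
    fn accept(&self, visitor: &dyn Visitor);
}

impl Visitee for i32 {
    fn accept(&self, visitor: &dyn Visitor) {
        visitor.visit_i32(self)
    }
}

impl Visitee for String {
    fn accept(&self, visitor: &dyn Visitor) {
        visitor.visit_str(self) // &String coerces to &str
    }
}

struct Printer;

impl Visitor for Printer {
    fn visit_i32(&self, value: &i32) { println!("i32: {}", value) }
    fn visit_str(&self, value: &str) { println!("str: {:?}", value) }
}

fn main() {
    let values: Vec<Box<dyn Visitee>> = vec![Box::new(1_i32), Box::new(String::from("two"))];
    for value in &values {
        value.accept(&Printer);
    }
}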

2 Likes

Correct, but then you would just pass a dyn Trait.

I don't see how that could work, either: if the type of self is not known at compile time, then overload resolution can't kick in to choose the correct overload. If, however, the type is known at the point of implementing the method (i.e. it's not a default implementation but inside impl Visitor for ConcreteType), then one can trivially replace overloading with distinct functions, and call the appropriate one (visit_i32, visit_str, etc.). Hence, overloading seems neither necessary nor sufficient for implementing a visitor.

Yes, it's not strictly necessary, but it's a lot more convenient if the accept method is just a single line and not a massive match that you have to extend every time you add a new type.

And leveraging method overloading for this is part of the original pattern, which is why I added "at least not straightforwardly".

I don't see why accept would become a match statement here unless the type you are implementing Visitee on is itself an enum.

1 Like

That's not helping at all, because the goal is to dispatch dynamic types to implementations for concrete types.

Yes, you're right. I was confused by my own pseudocode because I thought I could use the default implementation, which is of course not the case.

That does dispatch to concrete types; any trait method you call on a dyn Trait forwards to the underlying concrete type (unless you specifically implement a trait for dyn Trait). You can even get back to a value of the concrete type by using one of the Any::downcast_* methods if that is desired.
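
For reference, downcasting looks like this (a minimal, standalone illustration using &dyn Any directly):

use std::any::Any;

fn describe(value: &dyn Any) {
    // recover the concrete type at runtime, if it matches
    if let Some(n) = value.downcast_ref::<i32>() {
        println!("an i32: {}", n);
    } else if let Some(c) = value.downcast_ref::<char>() {
        println!("a char: {}", c);
    } else {
        println!("some other type");
    }
}

fn main() {
    describe(&5_i32);
    describe(&'x');
    describe(&[1u8, 2, 3]);
}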

1 Like

Yes, but it only dispatches on one type; I don't see how this would help:

trait Visitee {
    fn accept(&self, visitor: &dyn Visitor);
}

trait Visitor {
    fn visit(&self, visitee: &dyn Visitee);
}

That would just be a ping-pong where there's always one concrete type and one dynamic type.

1 Like

Check the playground in my first reply. You always dispatch on the next dyn Trait argument. (The rest of the arguments can just be passed down the line if needed, they don't participate in dispatching.)

I've seen that, but then you just have two separate functions with each one having only one concrete parameter, not a single function having both concrete parameters. That might be enough for some use cases but the OP specifically wanted dispatch on both parameters in a single function.

I don't think this can be achieved using only trait objects.

1 Like

The code in the playground is the essential building block for multiple dispatch; you can trivially swap out the literal println!()s for trait methods implementing the classic visitor pattern as discussed above (i.e., simply by using separate functions instead of a "single"/overloaded method), like this. Here, both parameters of both visitor methods are statically typed, there is no dyn Trait involved at this level.

To better demonstrate that claim of mine, here is another version that contains functions with all argument types spelled out explicitly (which the trait methods forward to), as well as a single generic double_dispatch() function that hides the dyn Trait method calls and appears as a regular free function with two arguments.
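
For anyone who can't open the playgrounds, the rough shape of that last version is something like the following. (This is only a sketch: the LevelOne/LevelTwo names follow the compiler output quoted later in the thread, while the with_* callback and free-function names are made up here; the actual playgrounds differ in details.)

// Free functions with all argument types spelled out; this is where the real work happens.
fn i32_and_i32(a: &i32, b: &i32)     { println!("(i32, i32): ({}, {})", a, b) }
fn i32_and_char(a: &i32, b: &char)   { println!("(i32, char): ({}, {})", a, b) }
fn char_and_i32(a: &char, b: &i32)   { println!("(char, i32): ({}, {})", a, b) }
fn char_and_char(a: &char, b: &char) { println!("(char, char): ({}, {})", a, b) }

trait LevelOne {
    // entry point: bounce over to the second argument, which knows its own concrete type
    fn level_one(&self, arg2: &dyn LevelTwo) {
        arg2.level_two(&self)
    }
    // callbacks with the second argument's type already fixed; both types are concrete here
    fn with_i32(&self, b: &i32);
    fn with_char(&self, b: &char);
}

// Forwarding impl so that `&self` (of type `&&Self`) in the default body above
// can coerce to `&dyn LevelOne`.
impl<T: LevelOne + ?Sized> LevelOne for &'_ T {
    fn level_one(&self, arg2: &dyn LevelTwo) {
        (**self).level_one(arg2)
    }
    fn with_i32(&self, b: &i32) {
        (**self).with_i32(b)
    }
    fn with_char(&self, b: &char) {
        (**self).with_char(b)
    }
}

trait LevelTwo {
    fn level_two(&self, arg1: &dyn LevelOne);
}

impl LevelOne for i32 {
    fn with_i32(&self, b: &i32)   { i32_and_i32(self, b) }
    fn with_char(&self, b: &char) { i32_and_char(self, b) }
}

impl LevelOne for char {
    fn with_i32(&self, b: &i32)   { char_and_i32(self, b) }
    fn with_char(&self, b: &char) { char_and_char(self, b) }
}

impl LevelTwo for i32 {
    fn level_two(&self, arg1: &dyn LevelOne) {
        // Self (= i32) is the second argument's concrete type; dispatch back on the first
        arg1.with_i32(self)
    }
}

impl LevelTwo for char {
    fn level_two(&self, arg1: &dyn LevelOne) {
        arg1.with_char(self)
    }
}

// A plain-looking generic function that hides the dyn calls.
fn double_dispatch<A: LevelOne, B: LevelTwo>(a: &A, b: &B) {
    a.level_one(b)
}

fn main() {
    double_dispatch(&5, &1);
    double_dispatch(&5, &'b');
    double_dispatch(&'c', &1);
}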

3 Likes

Depending on your use case, you might be able to do something like this:

use core::any::{Any, TypeId};
use std::marker::PhantomData as PhD;
use std::collections::{HashMap};

// The type-erased half: callable with two &dyn Any arguments.
trait CallDyn {
    type Output;
    fn try_call_dyn(&mut self, a:&dyn Any, b:&dyn Any)->Option<Self::Output>;
    unsafe fn call_dyn_unchecked(&mut self, a:&dyn Any, b:&dyn Any)->Self::Output;
}

// The statically typed half: callable with the concrete argument types.
trait CallPair {
    type A;
    type B;
    type Output;
    fn call_pair(&mut self, a:&Self::A, b:&Self::B)->Self::Output;
}

impl<F> CallDyn for F where
    F: CallPair,
    F::A: Any,
    F::B: Any,
{
    type Output = F::Output;
    fn try_call_dyn(&mut self, a:&dyn Any, b:&dyn Any)->Option<Self::Output> {
        Some(self.call_pair(a.downcast_ref()?, b.downcast_ref()?))
    }
    
    unsafe fn call_dyn_unchecked(&mut self, a:&dyn Any, b:&dyn Any)->Self::Output {
        self.call_pair(
            a.downcast_ref().unwrap_unchecked(),
            b.downcast_ref().unwrap_unchecked()
        )
    }
}

// Adapter that gives any FnMut(&A, &B) -> O closure or function a CallPair implementation.
struct Callback<A,B,O,F:FnMut(&A,&B)->O>(F, PhD<fn(&A,&B)->O>);

impl<A,B,O,F> CallPair for Callback<A,B,O,F>
    where F:FnMut(&A,&B)->O
{
    type A=A;
    type B=B;
    type Output=O;
    fn call_pair(&mut self, a:&Self::A, b:&Self::B)->Self::Output {
        (self.0)(a,b)
    }
}


struct Registry {
    // keyed by the TypeId of the (A, B) tuple of argument types
    map: HashMap<TypeId, Box<dyn CallDyn<Output=()>>>
}

impl Registry {
    fn new() -> Self {
        Self {
            map: HashMap::new()
        }
    }

    fn register_func<A: 'static, B: 'static, F:FnMut(&A,&B) + 'static>(&mut self, f: F) {
        let type_id = TypeId::of::<(A,B)>();
        self.map.insert(type_id, Box::new(Callback::<A,B,(),_>(f, PhD)));
    }

    fn call_func<A:'static, B:'static>(&mut self, a: &A, b: &B)->Option<()> {
        //Retrieve the registered callback (a boxed trait object rather than a raw function pointer)
        let type_id = TypeId::of::<(A,B)>();
        let generic_f = self.map.get_mut(&type_id)?;

        generic_f.try_call_dyn(a,b)
        // or:
        // Some( unsafe { generic_f.call_dyn_unchecked(a,b) } )
    }
}

//=========================================

fn pair_func_a(a: &i32, b: &i32) {
    println!("func_a({}, {})", a, b);
}

fn pair_func_b(a: &i32, b: &char) {
    println!("func_b({}, {})", a, b);
}

fn pair_func_c(a: &char, b: &i32) {
    println!("func_c({}, {})", a, b);
}

fn main() {

    let mut registry = Registry::new();

    registry.register_func(pair_func_a);
    registry.register_func(pair_func_b);
    registry.register_func(pair_func_c);

    registry.call_func(&5, &1);
    registry.call_func(&5, &'b');
    registry.call_func(&'c', &1);
}

2 Likes

I was trying to figure out how your double-dispatch version works, so I traced the functions with println!(), as in this Rust Playground

To my surprise, in the impl of LevelOne for &'_ T, the overriding level_one() method is unused, so its function body can be empty. Yet the presence of this method is required for the code to compile. Without it, compilation fails as below:

error: reached the recursion limit while instantiating `<&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&...&&&&&i32 as LevelOne>::level_one`
 --> src/main.rs:8:24
  |
8 |         arg2.level_two(&self)
  |                        ^^^^^
  |
note: `LevelOne::level_one` defined here
 --> src/main.rs:6:5
  |
6 |     fn level_one(&self, arg2: &dyn LevelTwo) {
  |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

What's going on here? Why is this method required to compile while being unused?

Well, for starters, there's nothing surprising about type checking needing a definition even though it may never be called. You aren't allowed to write if false { foo(); } if the function foo doesn't exist, even though it's obvious that it's never actually called.

Why the type-level infinite recursion happens in this case is due to the default implementation of LevelOne::level_one(): it's simply arg2.level_two(&self), meaning that it passes a reference of type &&T, where T is a concrete type (i32 or String). But due to the signature of LevelTwo::level_two(), this in turn means that it tries to forward to the level_one() method on the underlying reference type, and that will repeat the whole process with &&&T, and the stack of references grows indefinitely.
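
To spell it out, the failing shape is roughly this (a condensed sketch; the error only fires once level_one() is actually instantiated for a concrete type, e.g. via double_dispatch()):

trait LevelOne {
    fn level_one(&self, arg2: &dyn LevelTwo) {
        arg2.level_two(&self) // coerces `&&Self` to `&dyn LevelOne`
    }
    fn with_i32(&self, b: &i32);
}

trait LevelTwo {
    fn level_two(&self, arg1: &dyn LevelOne);
}

impl<T: LevelOne + ?Sized> LevelOne for &'_ T {
    // no `level_one` override here, so `&T` gets the default body too
    fn with_i32(&self, b: &i32) {
        (**self).with_i32(b)
    }
}

// Instantiating <i32 as LevelOne>::level_one coerces &&i32 to &dyn LevelOne, which needs
// the vtable for &i32; its level_one (the default body again) coerces &&&i32, which needs
// the vtable for &&i32, and so on until the recursion limit is reached.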

To be honest, I didn't even notice initially that not implementing the method wouldn't compile. Even though it was optional, I went ahead and implemented it manually anyway, because the correct and robust thing to do for a forwarding implementation is to forward methods to the well-defined underlying type explicitly. So I didn't even think about not implementing this method.

As for why it isn't actually invoked: I guess it's just how method resolution works, nothing really fundamental. In fn double_dispatch(), both arguments are reference types, and since LevelOne::level_one() takes &self, I'm thinking method resolution can readily pick the implementation on the concrete type (T) instead of the forwarding impl (&T).


Note that all this dance with the forwarding implementation on references is only needed for convenience, so that the default impl of level_one() can compile. In a default impl, the compiler pretends it doesn't know the concrete type (in order to avoid post-monomorphization errors), so the default impl can't know if Self: Sized, therefore passing level_two(self) wouldn't compile. In contrast, restricting the method's signature with where Self: Sized excludes the very possibility of calling it on a trait object, so it doesn't compile either way. Thus, we need to make up a Sized value even if Self: !Sized, and that's what a reference does.

If you are willing to re-implement the level_one() method for each concrete type (which I wanted to avoid because it introduces the potential for a bug), then you can do away with the forwarding implementation entirely, like this.
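
That variant looks roughly like this (again just a sketch, reusing the made-up with_* names from the earlier sketch):

trait LevelOne {
    // no default body and no forwarding impl for &T; every concrete type spells this out
    fn level_one(&self, arg2: &dyn LevelTwo);
    fn with_i32(&self, b: &i32);
    fn with_char(&self, b: &char);
}

trait LevelTwo {
    fn level_two(&self, arg1: &dyn LevelOne);
}

impl LevelOne for i32 {
    fn level_one(&self, arg2: &dyn LevelTwo) {
        arg2.level_two(self) // `self: &i32` coerces straight to `&dyn LevelOne`, since i32: Sized
    }
    fn with_i32(&self, b: &i32)   { println!("(i32, i32): ({}, {})", self, b) }
    fn with_char(&self, b: &char) { println!("(i32, char): ({}, {})", self, b) }
}

impl LevelOne for char {
    fn level_one(&self, arg2: &dyn LevelTwo) {
        arg2.level_two(self)
    }
    fn with_i32(&self, b: &i32)   { println!("(char, i32): ({}, {})", self, b) }
    fn with_char(&self, b: &char) { println!("(char, char): ({}, {})", self, b) }
}

impl LevelTwo for i32 {
    fn level_two(&self, arg1: &dyn LevelOne) { arg1.with_i32(self) }
}

impl LevelTwo for char {
    fn level_two(&self, arg1: &dyn LevelOne) { arg1.with_char(self) }
}

fn main() {
    let a: &dyn LevelOne = &5_i32;
    let b: &dyn LevelTwo = &'b';
    a.level_one(b); // prints "(i32, char): (5, b)"
}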

4 Likes

Great answer; in particular, it was nice of you to include a version without the forwarding impl for comparison. I was about to ask why you did the impl for &T. Really appreciate it.