How to implement functionality for a subset of objects that implement a trait

I have a generic trait for my ML library called Operator, defined like so:

pub trait Operator: Debug + TraitObjEq {
    fn process(&self, inp: Vec<Tensor>) -> Vec<Tensor>;
}

This implements a normal processing operation, where each operator takes in one or multiple Tensors and outputs one or multiple Tensors.

I keep a collection of these operators in a big datastructure as Box<dyn Operator> so they can be run later.

Now I want to define another trait to have some shared functionality between a subset of them. Say, some of the operators use the GPU, so I'd like a way to mark if an op uses the GPU. Ideally I'd like to not touch the Operator trait, but if I need to make a change to it, it should be in a way that doesn't change the other Operator impls. Of course this can be done simply as:

pub trait Operator: Debug + TraitObjEq {
    fn process(&self, inp: Vec<Tensor>) -> Vec<Tensor>;
    fn uses_gpu(&self) -> bool {false}
}

and then just implement uses_gpu as true on the special ops. The issue is that this works for the uses_gpu subset, but I potentially want many subgroups of ops, which means for each subgroup I'd need to add a new trait function, which isn't workable. I'd rather just implement a separate UsesGpu trait for all the GPU ops, but I don't know how to tell whether a Box<dyn Operator> implements UsesGpu.

What would be the recommended way to do this?

I'm not sure what aspect you find not workable, so this might not be suitable either.

But you could have

pub trait Operator: Debug + TraitObjEq {
    fn process(&self, inp: Vec<Tensor>) -> Vec<Tensor>;
    fn uses_gpu(&self) -> Option<&dyn UsesGpu> { None }
    fn other_magic(&self) -> Option<&dyn OtherMagic> { None }
    // ...
}

pub trait UsesGpu: Operator { /* ... */ }
pub trait OtherMagic: Operator { /* ... */ }
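A self-contained sketch of this accessor pattern, with `Tensor` and `TraitObjEq` dropped so it compiles on its own. The `Add`/`MatMulGpu` operators and the `kernel_name` method on `UsesGpu` are hypothetical stand-ins for whatever the real capability trait would provide:

```rust
use std::fmt::Debug;

// Simplified stand-in for the real Operator trait (Tensor/TraitObjEq omitted).
trait Operator: Debug {
    // Capability accessor; most operators keep the default `None`.
    fn uses_gpu(&self) -> Option<&dyn UsesGpu> {
        None
    }
}

// Capability trait; `kernel_name` is a hypothetical example method.
trait UsesGpu: Operator {
    fn kernel_name(&self) -> &'static str;
}

#[derive(Debug)]
struct Add;
#[derive(Debug)]
struct MatMulGpu;

impl Operator for Add {}

impl Operator for MatMulGpu {
    // Opt in by handing out itself as the capability object.
    fn uses_gpu(&self) -> Option<&dyn UsesGpu> {
        Some(self as &dyn UsesGpu)
    }
}

impl UsesGpu for MatMulGpu {
    fn kernel_name(&self) -> &'static str {
        "matmul_kernel"
    }
}

// Collect the kernel names of every operator that opted into UsesGpu.
fn gpu_kernels(ops: &[Box<dyn Operator>]) -> Vec<&'static str> {
    ops.iter()
        .filter_map(|op| op.uses_gpu())
        .map(|gpu| gpu.kernel_name())
        .collect()
}

fn main() {
    let ops: Vec<Box<dyn Operator>> = vec![Box::new(Add), Box::new(MatMulGpu)];
    println!("{:?}", gpu_kernels(&ops));
}
```

Note that `Operator` still has to name every capability trait, so this only works when the set of capabilities is known to the crate defining `Operator`.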

If you're imagining some chain of

if let Some(magic1) = bx.magic1() { do_it_one_way(magic1) }
else if let Some(magic2) = bx.magic2() { do_it_another_way(magic2) }
// ...
else { do_it_the_default_way(&bx) }

you may be half-way to painfully emulating an enum.
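For comparison, if the set of operator kinds is closed and known up front, a plain enum gives the same branching with exhaustive matching. The `Op` variants and the `describe` function below are hypothetical:

```rust
// Hypothetical operator kinds; each variant carries its own data.
#[derive(Debug)]
enum Op {
    Cpu { name: &'static str },
    Gpu { kernel: &'static str },
}

// The compiler checks that every kind is handled, unlike an if-let chain.
fn describe(op: &Op) -> String {
    match op {
        Op::Cpu { name } => format!("cpu:{name}"),
        Op::Gpu { kernel } => format!("gpu:{kernel}"),
    }
}

fn main() {
    let ops = vec![Op::Cpu { name: "add" }, Op::Gpu { kernel: "matmul" }];
    for op in &ops {
        println!("{}", describe(op));
    }
}
```

The trade-off is that a downstream crate cannot add variants to such an enum, which is exactly the extensibility limitation that comes up later in the thread.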


Why not just make the Operator trait include both CPU and GPU methods that default to a no-op, return an empty Vec, or just panic!()? You can also provide default implementations for the device selection and the process method, which saves rewriting them each time you implement Operator.

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device { Cpu, Gpu }

pub trait Operator: Debug + TraitObjEq {
    // const GPU: bool = false;
    // const DEVICE: Device = Device::Cpu;
    fn device(&self) -> Device { Device::Cpu }
    
    fn process(&self, inp: Vec<Tensor>) -> Vec<Tensor> {
        // match Self::DEVICE {
        match self.device() {
            Device::Cpu => self.process_cpu(inp),
            Device::Gpu => self.process_gpu(inp),
        }
    }
    
    fn process_cpu(&self, inp: Vec<Tensor>) -> Vec<Tensor> {
        inp
        // or vec![] / unreachable!() / unimplemented!() / panic!()
    }
    
    fn process_gpu(&self, inp: Vec<Tensor>) -> Vec<Tensor> { inp }
}

And if you're worried about a public API becoming too permissive, there are neat options for sealing traits or methods thereof.

Here is a playground exploring some of the above ideas.
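A runnable version of the default-method dispatch above, assuming `Tensor` is just `Vec<f32>` and dropping the `Debug + TraitObjEq` bounds. `Scale` and `GpuDouble` are made-up operators, and the "GPU" path here simply runs on the CPU:

```rust
// Assumption: a stand-in Tensor type for illustration only.
type Tensor = Vec<f32>;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device { Cpu, Gpu }

pub trait Operator {
    fn device(&self) -> Device { Device::Cpu }

    // Default dispatch: implementors usually override only one backend.
    fn process(&self, inp: Vec<Tensor>) -> Vec<Tensor> {
        match self.device() {
            Device::Cpu => self.process_cpu(inp),
            Device::Gpu => self.process_gpu(inp),
        }
    }

    // Defaults pass the inputs through unchanged.
    fn process_cpu(&self, inp: Vec<Tensor>) -> Vec<Tensor> { inp }
    fn process_gpu(&self, inp: Vec<Tensor>) -> Vec<Tensor> { inp }
}

struct Scale(f32);
struct GpuDouble;

impl Operator for Scale {
    fn process_cpu(&self, inp: Vec<Tensor>) -> Vec<Tensor> {
        inp.into_iter()
            .map(|t| t.into_iter().map(|x| x * self.0).collect::<Tensor>())
            .collect()
    }
}

impl Operator for GpuDouble {
    fn device(&self) -> Device { Device::Gpu }
    // Pretend this is a GPU kernel; it just doubles every element.
    fn process_gpu(&self, inp: Vec<Tensor>) -> Vec<Tensor> {
        inp.into_iter()
            .map(|t| t.into_iter().map(|x| x * 2.0).collect::<Tensor>())
            .collect()
    }
}

fn main() {
    let ops: Vec<Box<dyn Operator>> = vec![Box::new(Scale(3.0)), Box::new(GpuDouble)];
    let mut data = vec![vec![1.0, 2.0]];
    for op in &ops {
        data = op.process(data);
    }
    println!("{:?}", data);
}
```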

I guess my example was too simple. The reason I want to keep it as a separate trait is that I want to define shared functionality in subsets of operators even from other crates that the Operator trait may not reside in. So there may be a subset of operators that use Metal kernels, and they are in the metal crate. And there may be ones that use CUDA kernels, and they may be in the cuda crate. These crates should be able to be entirely third party, so I’m not able to have them change the main Operator trait.

I still don't understand the desired use and limitations. With a Vec<Box<dyn Operator>>, can't you just:

let gpu_ops = operator_vec.iter()
    .filter(|op| op.device() == Device::Gpu)
    .collect::<Vec<_>>();
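The `==` comparison only compiles if `Device` derives `PartialEq`. A minimal sketch of the filter with a stripped-down `Operator` trait and hypothetical `Relu`/`Conv` operators:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Device { Cpu, Gpu }

trait Operator {
    fn device(&self) -> Device { Device::Cpu }
}

struct Relu;
struct Conv;

impl Operator for Relu {}
impl Operator for Conv {
    fn device(&self) -> Device { Device::Gpu }
}

// Borrow the GPU operators out of the collection without moving them.
fn gpu_ops(ops: &[Box<dyn Operator>]) -> Vec<&Box<dyn Operator>> {
    ops.iter()
        .filter(|op| op.device() == Device::Gpu)
        .collect()
}

fn main() {
    let ops: Vec<Box<dyn Operator>> = vec![Box::new(Relu), Box::new(Conv), Box::new(Relu)];
    println!("{} GPU op(s)", gpu_ops(&ops).len());
}
```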

Also, if you provide ALL possible methods on Operator with default implementations, doesn't that free it up to be implemented by a third party? They would just need to select the proper methods to override.

Edit: I guess if you want to be able to add a "device" then it is locked...

You nailed it at the end, these subsets are not just about devices, they can really be anything. Some ops use different precisions vs others, some use different devices, some work distributed across multiple devices, etc. Devices are just one type of subset.

The specific thing I'm designing for right now is to have a subset of the ops use Metal, and define their kernels in a new trait, call it MetalKernel.

trait MetalKernel {
    fn run_metal_kernel(&self, input_buffers: &[Buffer]) -> Buffer {...}
}

and so I can loop through my Vec<Box<dyn Operator>> and, if the operator implements MetalKernel, call that and do some special stuff. Note this is all separate from Operator and process().

One way to do this would be to blanket impl MetalKernel for all operators, and make it a noop, and just implement it specifically on the actual metal ops, but I think that would require specialization.

Does this look closer to what you're after? playground

I believe it meets the orphan trait rules and fully allows a third party to be able to define and call their own internal trait (or even just inherent methods, including any special functions needed) using a custom entry point which has access to all Operation runtime selection details. And it avoids specialization entirely.

The last bit is selecting specific custom ops, so why not just make the Custom variant hold a name or other means (u128 hash value?) of uniquely identifying its implementers from others? Is it TraitObjEq that is meant to be able to differentiate? That could have a method which takes an Option<IDValue> for selecting subsets.
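The playground is not reproduced here, so this is only a guess at the shape being described: a `Custom` variant carrying a `&'static str` tag chosen by the third-party crate, and a filter on that tag. `Device::Custom`, `MetalAdd`, `count_custom`, and the `"metal"` tag are all hypothetical:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Device {
    Cpu,
    Gpu,
    // Third-party crates pick their own tag; nothing prevents collisions.
    Custom(&'static str),
}

trait Operator {
    fn device(&self) -> Device { Device::Cpu }
}

struct Relu;
struct MetalAdd;

impl Operator for Relu {}
impl Operator for MetalAdd {
    fn device(&self) -> Device { Device::Custom("metal") }
}

// Count operators whose device carries the given custom tag.
fn count_custom(ops: &[Box<dyn Operator>], tag: &str) -> usize {
    ops.iter()
        .filter(|op| matches!(op.device(), Device::Custom(t) if t == tag))
        .count()
}

fn main() {
    let ops: Vec<Box<dyn Operator>> = vec![Box::new(Relu), Box::new(MetalAdd)];
    println!("{}", count_custom(&ops, "metal"));
}
```

This identifies the subset, but calling a Metal-specific method on those operators still requires some path from `&dyn Operator` to that method.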

Thanks for the playground! Let me try to rephrase to get away from these device analogies.

I have a collection of Animals, and an Animal trait. All animal structs implement this animal trait, and so I store them all as Vec<Box<dyn Animal>>. Some of those animals have a bark() function. They implement the Bark trait. What I want to do is something like this:

let objects: Vec<Box<dyn Animal>> = vec![Box::new(Dog), Box::new(Cat), Box::new(Wolf)];
for obj in objects {
    if let Some(barker) = obj.downcast_ref::<dyn Bark>() {
        barker.bark();
    }
}

In this example, only Dog and Wolf implement Bark, so it would only get called on those two. Since Cat doesn't implement Bark, it wouldn't get called.

Again, this can easily be implemented with specialization by having a blanket impl of Bark on anything that implements Animal that is a noop, and then specifically implementing it on Dog and Wolf. Without specialization, I'm unsure how this could be implemented.

As for why it couldn't be implemented right in the Animal trait, I want to have other crates implement their own Bark-style traits, and inside those crates they need to be able to loop through Vec<Box<dyn Animal>> and call their Bark-style trait functions.

I hope this makes it more clear what I'm after!

Ah, this is a crucial aspect. There's no good way to do this when the subtraits are unknown. There are some compromised versions.

Rust doesn't have OO-class-like subtyping. There's no way to (attempt to) go directly from &dyn Super to &dyn Sub.[1] You'd have to first find a way from &dyn Super to &ConcreteType and then to &dyn Sub.

One (compromised) way is to go from &dyn Animal to &dyn Any[2] and then try to downcast to each concrete Bark implementor (and then optionally type erase to &dyn Bark).[3]
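A sketch of that compromise, assuming `Animal` gains an `as_any` method (not part of the original code). Whoever knows the concrete `Bark` implementors tries each one in turn, so adding a new barker means editing the hypothetical `as_bark` function by hand:

```rust
use std::any::Any;

trait Animal {
    // Each implementor returns itself; this is where 'static is required.
    fn as_any(&self) -> &dyn Any;
}

trait Bark {
    fn bark(&self) -> &'static str;
}

struct Dog;
struct Wolf;
struct Cat;

impl Animal for Dog { fn as_any(&self) -> &dyn Any { self } }
impl Animal for Wolf { fn as_any(&self) -> &dyn Any { self } }
impl Animal for Cat { fn as_any(&self) -> &dyn Any { self } }

impl Bark for Dog { fn bark(&self) -> &'static str { "woof" } }
impl Bark for Wolf { fn bark(&self) -> &'static str { "awoo" } }

// Try every known Bark implementor; the list is closed and maintained by hand.
fn as_bark(animal: &dyn Animal) -> Option<&dyn Bark> {
    let any = animal.as_any();
    if let Some(dog) = any.downcast_ref::<Dog>() {
        return Some(dog as &dyn Bark);
    }
    if let Some(wolf) = any.downcast_ref::<Wolf>() {
        return Some(wolf as &dyn Bark);
    }
    None
}

fn main() {
    let animals: Vec<Box<dyn Animal>> = vec![Box::new(Dog), Box::new(Cat), Box::new(Wolf)];
    for animal in &animals {
        if let Some(barker) = as_bark(animal.as_ref()) {
            println!("{}", barker.bark());
        }
    }
}
```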


You can generally go from a subtrait to a supertrait, because the subtrait knows about the supertrait. In this case the implicit trip from &dyn Sub to &ConcreteType is part of the compiler's implementation of Sub for dyn Sub. Then you get the concrete-typed coercion to &dyn Super.[4] However, that's the opposite direction from what you want.
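A small sketch of the direction that does work, using hypothetical `Super`, `Sub`, and `Thing` names: supertrait methods are callable through a `&dyn Sub`, and with the concrete type in hand, producing a `&dyn Super` is an ordinary unsizing coercion:

```rust
trait Super {
    fn name(&self) -> &'static str;
}

trait Sub: Super {
    fn extra(&self) -> u32;
}

struct Thing;

impl Super for Thing {
    fn name(&self) -> &'static str { "thing" }
}

impl Sub for Thing {
    fn extra(&self) -> u32 { 7 }
}

// Supertrait methods are reachable through the subtrait object.
fn name_via_sub(x: &dyn Sub) -> &'static str {
    x.name()
}

// Knowing the concrete type T, coercing to &dyn Super is straightforward.
fn to_super<T: Super>(x: &T) -> &dyn Super {
    x
}

fn main() {
    let t = Thing;
    println!("{} {} {}", name_via_sub(&t), to_super(&t).name(), t.extra());
}
```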

You can do a similar dance for any trait the supertrait knows about, actually, though less ergonomically -- implementors may have to supply a method since you can't cover everything in a blanket implementation. That's the approach of my suggestion above. But when the supertrait doesn't know about the subtrait, there's no way for it to do the last step.[5]


  1. There is no sub/super type relationship despite the terminology. ↩︎

  2. this part is possible... at the cost of a 'static restriction on Animal ↩︎

  3. Or use a type map or something, which similarly maps a finite set of types. ↩︎

  4. This will probably even be possible to do directly in the nearish future, without the trip through &ConcreteType. But that will be because subtrait vtables in some sense contain the supertrait vtable. ↩︎

  5. Technically you can sorta do it on nightly, but that's not object safe -- you can't make a dyn Animal anymore -- because (vtable considerations aside) the base type, and thus the satisfied bounds, are erased. There's no runtime querying of an erased type's (other) trait implementations. ↩︎

One way I was thinking of going about this (super hacky) was to have a function on Animal called custom() like so:

trait Animal {
    fn custom(&self, key: &str) -> Option<Box<dyn Any>>;
}

That way on dog, I can implement something like this:

trait Bark {
    fn bark(&self);
}

impl Bark for Dog {
    fn bark(&self) {...}
}

struct BarkWrapper(Box<dyn Bark>);

impl Animal for Dog {
    fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
        if key == "bark" {
            Some(Box::new(BarkWrapper(Box::new(self.clone()))))
        } else {
            None
        }
    }
}

And then in the loop, other crates can do something like this:

let animals: Vec<Box<dyn Animal>> = vec![Box::new(Dog), Box::new(Cat)];
for animal in animals {
    if let Some(returned_any) = animal.custom("bark") { // Will return a Box<dyn Any>
        if let Some(bark_wrapper) = returned_any.downcast_ref::<BarkWrapper>() {
            bark_wrapper.0.bark();
        }
    }
}

Don't know if this works yet, will try it later.

I think I got it!

Let me know if there's anything obviously wrong with this, but it seems to work on my end:

use std::{any::Any, fmt::Debug};

fn main() {
    let objects: Vec<Box<dyn Animal>> = vec![Box::new(Dog), Box::new(Cat)];

    for obj in objects {
        if let Some(any) = obj.custom("bark") {
            if let Some(bark_wrapper) = any.downcast_ref::<BarkWrapper>() {
                bark_wrapper.0.bark();
            }
        }
    }
}

#[derive(Debug, Clone)]
struct Dog;
#[derive(Debug)]
struct Cat;

trait Animal {
    #[allow(unused)]
    fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
        None
    }
}
impl Animal for Dog {
    fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
        if key == "bark" {
            Some(Box::new(BarkWrapper(Box::new(self.clone()))))
        } else {
            None
        }
    }
}
impl Animal for Cat {}

trait BarkTrait {
    fn bark(&self);
}

impl BarkTrait for Dog {
    fn bark(&self) {
        println!("woof");
    }
}

struct BarkWrapper(Box<dyn BarkTrait>);

Even though it's hacky because of the custom(key) and BarkWrapper thing, it works for me since

  1. Animal doesn't need to know about Bark
  2. Other subtraits can be added like Meow, Growl, etc. from other crates
  3. It allows those other crates to loop through a Vec<Box<dyn Animal>> and pull out the structs that implement the subtrait

I don't like that I need to clone Dog here, but I think there's a way to avoid it by having BarkWrapper be struct BarkWrapper<'a>(&'a dyn BarkTrait), though I haven't worked out the lifetimes in the traits yet.
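One catch with a borrowing wrapper: `Any` requires `'static`, so a `BarkWrapper<'a>` cannot itself be boxed as `dyn Any`. A possible way around the clone (a sketch, not something proposed in the thread) is to box a `'static` function pointer that performs the cast, and to add an `as_any` accessor on `Animal` so the caller can pass the original object to it. `BarkCaster`, `cast_to_bark`, `as_any`, and `try_bark` are hypothetical names:

```rust
use std::any::Any;

trait Bark {
    fn bark(&self) -> &'static str;
}

// A fn pointer is 'static, so it can travel inside a Box<dyn Any>.
type BarkCaster = fn(&dyn Any) -> Option<&dyn Bark>;

// Generic caster: downcast to the concrete type, then re-erase as dyn Bark.
fn cast_to_bark<T: Bark + 'static>(any: &dyn Any) -> Option<&dyn Bark> {
    any.downcast_ref::<T>().map(|t| t as &dyn Bark)
}

trait Animal {
    fn as_any(&self) -> &dyn Any;
    fn custom(&self, _key: &str) -> Option<Box<dyn Any>> {
        None
    }
}

struct Dog;
struct Cat;

impl Bark for Dog {
    fn bark(&self) -> &'static str { "woof" }
}

impl Animal for Dog {
    fn as_any(&self) -> &dyn Any { self }
    fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
        // No clone of Dog: only the caster function pointer is boxed.
        if key == "bark" {
            Some(Box::new(cast_to_bark::<Dog> as BarkCaster))
        } else {
            None
        }
    }
}

impl Animal for Cat {
    fn as_any(&self) -> &dyn Any { self }
}

// What the loop body in another crate could look like.
fn try_bark(animal: &dyn Animal) -> Option<&'static str> {
    let any = animal.custom("bark")?;
    let cast = any.downcast_ref::<BarkCaster>()?;
    cast(animal.as_any()).map(|b| b.bark())
}

fn main() {
    let animals: Vec<Box<dyn Animal>> = vec![Box::new(Dog), Box::new(Cat)];
    for animal in &animals {
        if let Some(sound) = try_bark(animal.as_ref()) {
            println!("{sound}");
        }
    }
}
```

The cost is one extra step for the caller, which has to feed `as_any()` back into the caster it received.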

If Box<dyn Bark> implements Bark and Animal, you could do double-boxing and downcast to Box<dyn Bark>. That's cleaner than most (safe[1]) solutions.

Well, the crate is unmaintained and unsound. I didn't dig much deeper than that.


  1. and sound ↩︎
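One reading of the double-boxing suggestion, sketched under assumptions: `Animal` gains an `as_any` method, `Bark` has `Animal` as a supertrait, and the crate defining `Bark` implements `Animal` for `Box<dyn Bark>`. Barkers are stored as `Box<Box<dyn Bark>>` inside the `Vec<Box<dyn Animal>>`, and the loop downcasts to `Box<dyn Bark>`. The `name` method, `barks`, and `zoo` are hypothetical:

```rust
use std::any::Any;

trait Animal {
    fn name(&self) -> &'static str;
    fn as_any(&self) -> &dyn Any;
}

// Would live in a downstream crate; Animal is its supertrait.
trait Bark: Animal {
    fn bark(&self) -> &'static str;
}

// The crate that owns `dyn Bark` may implement Animal for Box<dyn Bark>.
impl Animal for Box<dyn Bark> {
    fn name(&self) -> &'static str {
        // Call through the inner trait object; `self.name()` would recurse.
        (**self).name()
    }
    fn as_any(&self) -> &dyn Any {
        // Expose the outer Box so callers can downcast to Box<dyn Bark>.
        self
    }
}

struct Dog;
struct Cat;

impl Animal for Dog {
    fn name(&self) -> &'static str { "dog" }
    fn as_any(&self) -> &dyn Any { self }
}
impl Bark for Dog {
    fn bark(&self) -> &'static str { "woof" }
}
impl Animal for Cat {
    fn name(&self) -> &'static str { "cat" }
    fn as_any(&self) -> &dyn Any { self }
}

fn barks(animals: &[Box<dyn Animal>]) -> Vec<&'static str> {
    animals
        .iter()
        .filter_map(|a| a.as_any().downcast_ref::<Box<dyn Bark>>())
        .map(|b| b.bark())
        .collect()
}

fn zoo() -> Vec<Box<dyn Animal>> {
    // The barker is wrapped twice: Box<Box<dyn Bark>> coerces to Box<dyn Animal>.
    vec![Box::new(Box::new(Dog) as Box<dyn Bark>), Box::new(Cat)]
}

fn main() {
    println!("{:?}", barks(&zoo()));
}
```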


True, turns out I don't need it anyway. Edited my answer.
