Trait object with associated types

Hello, I am attempting to use a trait object so that a field in a struct can be a vector of heterogeneous types. The elements in the vector are rules (implemented as structs) with common methods that can be called. The rules are not known at compile time: they are passed in by a Python user and mapped to the appropriate structs. Additionally, each rule may operate on a different primitive type, such as f64 or i32.

I am getting the following compilation error when defining the type of the rules field in the struct Rules:

the value of the associated type `T` in `SplitRule` must be specified

use rand::Rng;
use std::f64;

/// Interface for split strategies.
trait SplitRule {
    type T;

    fn get_split_value(&self, candidates: &[Self::T]) -> Option<Self::T>;
    fn divide(&self, candidates: &[Self::T], split_value: Self::T) -> Vec<bool>;
}

/// Standard continuous split rule. Pick a pivot value and split
/// depending on if variable is smaller or greater than the value picked.
struct ContinuousSplit;

impl SplitRule for ContinuousSplit {
    type T = f64;

    fn get_split_value(&self, candidates: &[f64]) -> Option<f64> {
        if candidates.len() > 1 {
            let idx = rand::thread_rng().gen_range(0..candidates.len());
            Some(candidates[idx])
        } else {
            None
        }
    }

    fn divide(&self, candidates: &[f64], split_value: f64) -> Vec<bool> {
        candidates.iter().map(|&x| x <= split_value).collect()
    }
}

struct Rules {
    // rustc: the value of the associated type `T` in `SplitRule` must be specified
    rules: Vec<Box<dyn SplitRule>>,
}

fn main() {
    // This is actually the Python user input
    let rules = Rules {
        rules: vec![
            Box::new(ContinuousSplit),
            Box::new(ContinuousSplit),
            Box::new(<different rule here>),
        ],
    };

    let feature: usize = 1;
    let selected_rule = &rules.rules[feature];

    let split_value = match selected_rule.get_split_value(<vector of candidate points>) {
        Some(value) => value,
        None => {
            // main returns (), so bail out with a plain return
            return;
        }
    };

    // split_value is used later in the program...
}
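For context, the error goes away once the associated type is pinned in the trait object type, e.g. `Box<dyn SplitRule<T = f64>>`. The catch is that every element of the vector must then share that same `T`, which is exactly why mixing f64 and i32 rules needs something more. A minimal sketch, using a hypothetical deterministic `MiddleSplit` rule in place of `ContinuousSplit` so it runs without the `rand` crate:

```rust
/// Interface for split strategies (same shape as in the question).
trait SplitRule {
    type T;

    fn get_split_value(&self, candidates: &[Self::T]) -> Option<Self::T>;
    fn divide(&self, candidates: &[Self::T], split_value: Self::T) -> Vec<bool>;
}

/// Hypothetical deterministic rule: pick the middle candidate as the pivot.
struct MiddleSplit;

impl SplitRule for MiddleSplit {
    type T = f64;

    fn get_split_value(&self, candidates: &[f64]) -> Option<f64> {
        if candidates.len() > 1 {
            Some(candidates[candidates.len() / 2])
        } else {
            None
        }
    }

    fn divide(&self, candidates: &[f64], split_value: f64) -> Vec<bool> {
        candidates.iter().map(|&x| x <= split_value).collect()
    }
}

struct Rules {
    // Specifying `T` completes the trait object type, so this compiles,
    // but only rules whose `T` is f64 can go into this vector.
    rules: Vec<Box<dyn SplitRule<T = f64>>>,
}

fn main() {
    let rules = Rules {
        rules: vec![Box::new(MiddleSplit), Box::new(MiddleSplit)],
    };

    let feature: usize = 1;
    let selected_rule = &rules.rules[feature];

    let candidates = vec![0.5, 2.0, 1.0, 3.0];
    let Some(split_value) = selected_rule.get_split_value(&candidates) else {
        return;
    };
    let mask = selected_rule.divide(&candidates, split_value);
    println!("split at {split_value}: {mask:?}");
}
```

Trying to push an i32 rule into `rules` here is a type error, which is the heterogeneity problem the rest of the thread is about.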

I have a working implementation (see below) using dyn Any, but it just feels "wrong". I suppose I could also use an enum, but then there is a lot of control flow in each implementation, which kind of defeats the purpose of a "generic" type.

I haven't had any luck with the following topics:

Working implementation using dyn Any

use rand::Rng;
use std::any::Any;

// `pub` because it appears in the signatures of the public trait below
pub enum SplitValue {
    Float(f64),
    Integer(i32),
}

/// Interface for split strategies.
pub trait SplitRule: Send + Sync {
    fn as_any(&self) -> &dyn Any;
    fn get_split_value_dyn(&self, candidates: &dyn Any) -> Option<SplitValue>;
    fn divide_dyn(
        &self,
        candidates: &dyn Any,
        split_value: &SplitValue,
    ) -> (Vec<usize>, Vec<usize>);
}

/// Standard continuous split rule. Pick a pivot value and split
/// depending on if variable is smaller or greater than the value picked.
pub struct ContinuousSplit;

impl SplitRule for ContinuousSplit {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_split_value_dyn(&self, candidates: &dyn Any) -> Option<SplitValue> {
        if let Some(candidates) = candidates.downcast_ref::<Vec<f64>>() {
            if candidates.len() > 1 {
                let idx = rand::thread_rng().gen_range(0..candidates.len());
                Some(SplitValue::Float(candidates[idx]))
            } else {
                None
            }
        } else {
            None
        }
    }

    fn divide_dyn(
        &self,
        candidates: &dyn Any,
        split_value: &SplitValue,
    ) -> (Vec<usize>, Vec<usize>) {
        if let Some(candidates) = candidates.downcast_ref::<Vec<f64>>() {
            match split_value {
                SplitValue::Float(threshold) => {
                    let (left, right): (Vec<_>, Vec<_>) =
                        (0..candidates.len()).partition(|&idx| candidates[idx] <= *threshold);
                    (left, right)
                }
                SplitValue::Integer(threshold) => {
                    let threshold = *threshold as f64;
                    let (left, right): (Vec<_>, Vec<_>) =
                        (0..candidates.len()).partition(|&idx| candidates[idx] <= threshold);
                    (left, right)
                }
            }
        } else {
            (vec![], vec![])
        }
    }
}

struct Rules {
    // Compiler is happy...
    rules: Vec<Box<dyn SplitRule>>,
}
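If `dyn Any` has to stay, one way to make it feel less "wrong" is to confine the downcast to a single generic impl: each rule implements only the typed `SplitRule` from the first snippet, and a blanket impl of a hypothetical object-safe `DynSplitRule` trait does the downcast once for all of them. A minimal sketch, assuming candidate columns are stored as `Vec<T>` behind `dyn Any`, and using hypothetical deterministic rules `MiddleSplit` (f64) and `FixedSplit` (i32) instead of the random pivot:

```rust
use std::any::Any;

/// Typed interface, as in the first snippet of the question.
trait SplitRule {
    type T;

    fn get_split_value(&self, candidates: &[Self::T]) -> Option<Self::T>;
    fn divide(&self, candidates: &[Self::T], split_value: Self::T) -> Vec<bool>;
}

/// Hypothetical object-safe, type-erased interface. Returns
/// (left indices, right indices), or None if the candidates have the
/// wrong element type or no split value could be chosen.
trait DynSplitRule {
    fn split_dyn(&self, candidates: &dyn Any) -> Option<(Vec<usize>, Vec<usize>)>;
}

/// Implemented once for every typed rule, so rules never touch `Any`.
impl<R> DynSplitRule for R
where
    R: SplitRule,
    R::T: 'static,
{
    fn split_dyn(&self, candidates: &dyn Any) -> Option<(Vec<usize>, Vec<usize>)> {
        // The only downcast in the program: to the rule's own element type.
        let candidates = candidates.downcast_ref::<Vec<R::T>>()?;
        let split_value = self.get_split_value(candidates)?;
        let mask = self.divide(candidates, split_value);
        let (left, right): (Vec<usize>, Vec<usize>) = (0..mask.len()).partition(|&i| mask[i]);
        Some((left, right))
    }
}

/// Hypothetical f64 rule: pivot on the middle candidate.
struct MiddleSplit;

impl SplitRule for MiddleSplit {
    type T = f64;

    fn get_split_value(&self, c: &[f64]) -> Option<f64> {
        (c.len() > 1).then(|| c[c.len() / 2])
    }

    fn divide(&self, c: &[f64], v: f64) -> Vec<bool> {
        c.iter().map(|&x| x <= v).collect()
    }
}

/// Hypothetical i32 rule: always split at a fixed threshold.
struct FixedSplit(i32);

impl SplitRule for FixedSplit {
    type T = i32;

    fn get_split_value(&self, c: &[i32]) -> Option<i32> {
        (!c.is_empty()).then_some(self.0)
    }

    fn divide(&self, c: &[i32], v: i32) -> Vec<bool> {
        c.iter().map(|&x| x <= v).collect()
    }
}

fn main() {
    // One f64 rule and one i32 rule behind the same trait object type.
    let rules: Vec<Box<dyn DynSplitRule>> = vec![Box::new(MiddleSplit), Box::new(FixedSplit(0))];
    let columns: Vec<Box<dyn Any>> = vec![
        Box::new(vec![0.5_f64, 2.0, 1.0, 3.0]),
        Box::new(vec![-3_i32, 4, 0, 7, -1]),
    ];

    for (rule, column) in rules.iter().zip(&columns) {
        // `&**column` reaches the inner `dyn Any`; passing `column` directly
        // would coerce the Box itself to `dyn Any` and every downcast would fail.
        println!("{:?}", rule.split_dyn(&**column));
    }
}
```

The trade-off versus the version above is that the `SplitValue` enum disappears from the public interface; the price is that the split value never leaves `split_dyn`, so this only fits if the caller needs the partition rather than the raw pivot.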

Thanks for the help and thoughts.

The combination of the first two parts of the quotes requires some sort of branching somewhere. I'm not sure where you were imagining the enum, but you could have an enum with a variant for every Box<dyn SplitRule<T = primitive>>. Then the rule implementations don't have to care: instead of each one downcasting or matching on a value enum, the branching happens once, where the rule enum is matched.[1]

Because Rust is statically typed, you'll need to know the type of your candidate points before you create them anyway.
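A minimal sketch of the enum-of-trait-objects idea, assuming just two element types (f64 and i32) and a hypothetical `Column` enum that tags the candidate points with their type; `MiddleSplit` is a hypothetical deterministic rule standing in for the real ones. The rules stay fully typed, and the only branching is the single `match` in `AnyRule::split`:

```rust
use std::marker::PhantomData;

trait SplitRule {
    type T;

    fn get_split_value(&self, candidates: &[Self::T]) -> Option<Self::T>;
    fn divide(&self, candidates: &[Self::T], split_value: Self::T) -> Vec<bool>;
}

/// Hypothetical generic rule: pivot on the middle candidate.
struct MiddleSplit<V>(PhantomData<V>);

impl<V: Copy + PartialOrd> SplitRule for MiddleSplit<V> {
    type T = V;

    fn get_split_value(&self, candidates: &[V]) -> Option<V> {
        if candidates.len() > 1 {
            Some(candidates[candidates.len() / 2])
        } else {
            None
        }
    }

    fn divide(&self, candidates: &[V], split_value: V) -> Vec<bool> {
        candidates.iter().map(|&x| x <= split_value).collect()
    }
}

/// One variant per primitive type, each holding a typed trait object.
enum AnyRule {
    Float(Box<dyn SplitRule<T = f64>>),
    Integer(Box<dyn SplitRule<T = i32>>),
}

/// Hypothetical: candidate points tagged with their element type.
enum Column {
    Float(Vec<f64>),
    Integer(Vec<i32>),
}

impl AnyRule {
    /// The one place that branches on type. Returns None if the rule and
    /// column types disagree or the rule finds no split value.
    fn split(&self, column: &Column) -> Option<Vec<bool>> {
        match (self, column) {
            (AnyRule::Float(rule), Column::Float(c)) => {
                let v = rule.get_split_value(c)?;
                Some(rule.divide(c, v))
            }
            (AnyRule::Integer(rule), Column::Integer(c)) => {
                let v = rule.get_split_value(c)?;
                Some(rule.divide(c, v))
            }
            _ => None,
        }
    }
}

fn main() {
    let rules = vec![
        AnyRule::Float(Box::new(MiddleSplit::<f64>(PhantomData))),
        AnyRule::Integer(Box::new(MiddleSplit::<i32>(PhantomData))),
    ];
    let columns = vec![
        Column::Float(vec![3.0, 1.0, 2.0]),
        Column::Integer(vec![10, 20, 30, 40]),
    ];

    for (rule, column) in rules.iter().zip(&columns) {
        println!("{:?}", rule.split(column));
    }
}
```

A mismatched rule/column pair falls through to `None` here; whether that should be a hard error instead is a design choice for the Python-facing layer.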


  1. The two are about equally noisy when looking for one specific type/variant, and downcasting is worse when looking for more than one.

Thanks for the quick reply @quinedot

The combination of the first two parts of the quotes requires some sort of branching somewhere

This is the conclusion I came to as well. Thanks for the thought regarding the enum. I'll come back here if I find a more "elegant" solution. It works for now.

Thanks!