Hello, I am attempting to use a trait object to allow a field in a struct to be a vector of heterogeneous types. The elements in the vector are rules (implemented as structs) with common methods that can be called. The rules are not known at compile time, as they are passed in by a Python user and mapped to the appropriate structs. Additionally, each rule may operate on a different primitive type, such as `f64` or `i32`.
I am getting the following compilation error when defining the type of `rules` in the struct `Rules`:
the value of the associated type `T` in `SplitRule` must be specified
use rand::Rng;
use std::f64;
/// Interface for split strategies.
trait SplitRule {
    type T;
    fn get_split_value(&self, candidates: &[Self::T]) -> Option<Self::T>;
    fn divide(&self, candidates: &[Self::T], split_value: Self::T) -> Vec<bool>;
}
/// Standard continuous split rule. Pick a pivot value and split
/// depending on if variable is smaller or greater than the value picked.
struct ContinuousSplit;
impl SplitRule for ContinuousSplit {
    type T = f64;
    fn get_split_value(&self, candidates: &[f64]) -> Option<f64> {
        if candidates.len() > 1 {
            let idx = rand::thread_rng().gen_range(0..candidates.len());
            Some(candidates[idx])
        } else {
            None
        }
    }
    fn divide(&self, candidates: &[f64], split_value: f64) -> Vec<bool> {
        candidates.iter().map(|&x| x <= split_value).collect()
    }
}
struct Rules {
    // rustc: the value of the associated type `T` in `SplitRule` must be specified
    rules: Vec<Box<dyn SplitRule>>,
}
fn main() {
    // This is actually the Python user input
    let rules = Rules {
        rules: vec![
            Box::new(ContinuousSplit),
            Box::new(ContinuousSplit),
            Box::new(<different rule here>),
        ],
    };
    let feature: usize = 1;
    let selected_rule = &rules.rules[feature];
    let split_value = match selected_rule.get_split_value(<vector of candidate points>) {
        Some(value) => value,
        None => {
            // Nothing to split on, so bail out early.
            return;
        }
    };
    // split_value is used later in the program...
}
I have a working implementation (see below) that uses `dyn Any`, but this just feels "wrong". I suppose I could also use an enum, but then there is a lot of control flow in each implementation, which kind of defeats the purpose of a "generic" type.
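
To make the enum idea concrete, here is roughly what I have in mind. This is only a minimal sketch: `Rule`, `Candidates`, the `OneHot` rule, and picking the middle candidate instead of a random one are all made up for illustration and are not part of my actual code.

```rust
/// Hypothetical enum of candidate columns, one variant per primitive type.
#[derive(Debug, Clone, PartialEq)]
enum Candidates {
    Float(Vec<f64>),
    Integer(Vec<i32>),
}

/// Hypothetical enum of split values, mirroring `Candidates`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum SplitValue {
    Float(f64),
    Integer(i32),
}

/// One variant per rule instead of one struct per rule.
enum Rule {
    Continuous,
    // Hypothetical second rule that only accepts integer candidates.
    OneHot,
}

impl Rule {
    fn get_split_value(&self, candidates: &Candidates) -> Option<SplitValue> {
        // Assumption: take the middle candidate (not a random one) so the
        // sketch is deterministic and needs no external crate.
        match (self, candidates) {
            (Rule::Continuous, Candidates::Float(xs)) if xs.len() > 1 => {
                Some(SplitValue::Float(xs[xs.len() / 2]))
            }
            (Rule::OneHot, Candidates::Integer(xs)) if xs.len() > 1 => {
                Some(SplitValue::Integer(xs[xs.len() / 2]))
            }
            // Too few candidates, or a rule/type combination that is not supported.
            _ => None,
        }
    }

    fn divide(&self, candidates: &Candidates, split_value: SplitValue) -> Vec<bool> {
        match (self, candidates, split_value) {
            (Rule::Continuous, Candidates::Float(xs), SplitValue::Float(t)) => {
                xs.iter().map(|&x| x <= t).collect()
            }
            (Rule::OneHot, Candidates::Integer(xs), SplitValue::Integer(t)) => {
                xs.iter().map(|&x| x == t).collect()
            }
            // Every mismatched combination has to be handled somewhere.
            _ => Vec::new(),
        }
    }
}

fn main() {
    let rules = vec![Rule::Continuous, Rule::OneHot];
    let columns = vec![
        Candidates::Float(vec![1.0, 2.0, 3.0]),
        Candidates::Integer(vec![4, 5, 6]),
    ];
    for (rule, column) in rules.iter().zip(&columns) {
        if let Some(split_value) = rule.get_split_value(column) {
            println!("{:?} -> {:?}", split_value, rule.divide(column, split_value));
        }
    }
}
```

The vector of rules is easy to type this way (`Vec<Rule>`), but every method ends up matching on the rule and on the value type, which is the control flow I was hoping to avoid.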
I haven't had any luck with the following topics:
- Trait objects with associated types - #20 by matthewhammer
- Trait Objects with Associated Types part 2 - #2 by vitalyd
Working implementation using dyn Any
use rand::Rng;
use std::any::Any;
use std::f64;
use std::iter::Iterator;
// Public because it appears in the signatures of the public trait below.
pub enum SplitValue {
    Float(f64),
    Integer(i32),
}
/// Interface for split strategies.
pub trait SplitRule: Send + Sync {
    fn as_any(&self) -> &dyn Any;
    fn get_split_value_dyn(&self, candidates: &dyn Any) -> Option<SplitValue>;
    fn divide_dyn(
        &self,
        candidates: &dyn Any,
        split_value: &SplitValue,
    ) -> (Vec<usize>, Vec<usize>);
}
/// Standard continuous split rule. Pick a pivot value and split
/// depending on if variable is smaller or greater than the value picked.
pub struct ContinuousSplit;
impl SplitRule for ContinuousSplit {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn get_split_value_dyn(&self, candidates: &dyn Any) -> Option<SplitValue> {
        if let Some(candidates) = candidates.downcast_ref::<Vec<f64>>() {
            if candidates.len() > 1 {
                let idx = rand::thread_rng().gen_range(0..candidates.len());
                Some(SplitValue::Float(candidates[idx]))
            } else {
                None
            }
        } else {
            None
        }
    }
    fn divide_dyn(
        &self,
        candidates: &dyn Any,
        split_value: &SplitValue,
    ) -> (Vec<usize>, Vec<usize>) {
        if let Some(candidates) = candidates.downcast_ref::<Vec<f64>>() {
            match split_value {
                SplitValue::Float(threshold) => {
                    let (left, right): (Vec<_>, Vec<_>) =
                        (0..candidates.len()).partition(|&idx| candidates[idx] <= *threshold);
                    (left, right)
                }
                SplitValue::Integer(threshold) => {
                    let threshold = *threshold as f64;
                    let (left, right): (Vec<_>, Vec<_>) =
                        (0..candidates.len()).partition(|&idx| candidates[idx] <= threshold);
                    (left, right)
                }
            }
        } else {
            (vec![], vec![])
        }
    }
}
struct Rules {
    // Compiler is happy...
    rules: Vec<Box<dyn SplitRule>>,
}
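
To show why this feels wrong to me, here is a stripped-down sketch of what the call site looks like with the `dyn Any` version. It is a simplified stand-in for the code above, not the real thing: the trait is cut down to one method, and I take the first candidate instead of a random one (my assumption, just to keep the example deterministic and free of `rand`).

```rust
use std::any::Any;

#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum SplitValue {
    Float(f64),
    Integer(i32),
}

/// Cut-down version of the trait: only the split-value method.
trait SplitRule {
    fn get_split_value_dyn(&self, candidates: &dyn Any) -> Option<SplitValue>;
}

struct ContinuousSplit;

impl SplitRule for ContinuousSplit {
    fn get_split_value_dyn(&self, candidates: &dyn Any) -> Option<SplitValue> {
        // The downcast only succeeds if the caller passed exactly a Vec<f64>.
        let candidates = candidates.downcast_ref::<Vec<f64>>()?;
        if candidates.len() > 1 {
            // Assumption: first element instead of a random index.
            Some(SplitValue::Float(candidates[0]))
        } else {
            None
        }
    }
}

fn main() {
    let rules: Vec<Box<dyn SplitRule>> = vec![Box::new(ContinuousSplit)];
    let floats: Vec<f64> = vec![0.5, 1.5, 2.5];
    let ints: Vec<i32> = vec![1, 2, 3];

    // Matching type: the downcast succeeds and a split value comes back.
    println!("{:?}", rules[0].get_split_value_dyn(&floats));
    // Wrong element type: this still compiles, and the rule just returns None.
    println!("{:?}", rules[0].get_split_value_dyn(&ints));
}
```

A type mismatch between the rule and its candidates is only detected at run time, and it is indistinguishable from "not enough candidates" because both paths return `None`.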
Thanks for the help and thoughts.