It seems smartcore is fairly "popular" as far as Rust machine learning crates go, but I'm struggling to find more than a couple of bare-bones examples. Like the title says, I'm trying to perform grid search parameter tuning with cross-validation on a random forest classifier, but I'm getting an error in the cross_validate() function. Here's the excerpt of the function I'm having an issue with (ignoring the return type and actual search values since I'm planning on updating those):
use smartcore::linalg::basic::arrays::{MutArray, Array2};
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::metrics::{mean_squared_error, accuracy};
use smartcore::model_selection::{train_test_split, cross_validate, KFold};
use smartcore::ensemble::random_forest_classifier::*;

fn grid_search(x_train: &DenseMatrix<f64>, y_train: &Vec<u32>) -> Option<(i32, u16, usize, usize)> {
    let mut best_score = 0.;
    let mut best_params = None;
    let n_trees = vec![10, 50, 100, 200, 500];
    let max_depth = vec![2, 4, 6, 8, 10];
    let min_samples_split = vec![2, 4, 6, 8, 10];
    let min_samples_leaf = vec![1, 2, 4, 6, 8, 10];
    let max_features = vec![2, 4, 6, 8, 10];
    for n_tree in n_trees {
        for m_depth in &max_depth {
            for m_split in &min_samples_split {
                for m_leaf in &min_samples_leaf {
                    for m_feat in &max_features {
                        let cv_score = cross_validate(
                            RandomForestClassifier::fit,
                            x_train,
                            y_train,
                            Default::default()
                                .with_n_trees(*m_feat)
                                .with_max_depth(*m_depth)
                                .with_min_samples_split(*m_split)
                                .with_min_samples_leaf(*m_leaf),
                            &KFold::default().with_n_splits(10),
                            &accuracy
                        ).unwrap();
                        if cv_score.mean_test_score() > best_score {
                            best_score = cv_score.mean_test_score();
                            best_params = Some((n_tree, m_depth, m_split, m_leaf));
                        }
                    }
                }
            }
        }
    }
    println!("Best score: {}", best_score);
    println!("Best params: {:?}", best_params);
    Some(
        (
            best_params.unwrap().0,
            *best_params.unwrap().1,
            *best_params.unwrap().2,
            *best_params.unwrap().3
        )
    )
}
I'm having a hard time interpreting the error message, since it says that the fit function passed to cross_validate needs to implement the trait SupervisedEstimator, which, as far as I can tell, it does:
the trait bound `for<'r, 's> fn(&'r _, &'s _, smartcore::ensemble::random_forest_classifier::RandomForestClassifierParameters) -> std::result::Result<smartcore::ensemble::random_forest_classifier::RandomForestClassifier<_, _, _, _>, smartcore::error::Failed> {smartcore::ensemble::random_forest_classifier::RandomForestClassifier::<_, _, _, _>::fit}: smartcore::api::SupervisedEstimator<_, _, _>` is not satisfied
the following other types implement trait `smartcore::api::SupervisedEstimator<X, Y, P>`:
<smartcore::ensemble::random_forest_classifier::RandomForestClassifier<TX, TY, X, Y> as smartcore::api::SupervisedEstimator<X, Y, smartcore::ensemble::random_forest_classifier::RandomForestClassifierParameters>>
<smartcore::ensemble::random_forest_regressor::RandomForestRegressor<TX, TY, X, Y> as smartcore::api::SupervisedEstimator<X, Y, smartcore::ensemble::random_forest_regressor::RandomForestRegressorParameters>>
<smartcore::linear::elastic_net::ElasticNet<TX, TY, X, Y> as smartcore::api::SupervisedEstimator<X, Y, smartcore::linear::elastic_net::ElasticNetParameters>>
<smartcore::linear::lasso::Lasso<TX, TY, X, Y> as smartcore::api::SupervisedEstimator<X, Y, smartcore::linear::lasso::LassoParameters>>
<smartcore::linear::linear_regression::LinearRegression<TX, TY, X, Y> as smartcore::api::SupervisedEstimator<X, Y, smartcore::linear::linear_regression::LinearRegressionParameters>>
<smartcore::linear::logistic_regression::LogisticRegression<TX, TY, X, Y> as smartcore::api::SupervisedEstimator<X, Y, smartcore::linear::logistic_regression::LogisticRegressionParameters<TX>>>
<smartcore::linear::ridge_regression::RidgeRegression<TX, TY, X, Y> as smartcore::api::SupervisedEstimator<X, Y, smartcore::linear::ridge_regression::RidgeRegressionParameters<TX>>>
<smartcore::naive_bayes::bernoulli::BernoulliNB<TX, TY, X, Y> as smartcore::api::SupervisedEstimator<X, Y, smartcore::naive_bayes::bernoulli::BernoulliNBParameters<TX>>>
and 7 others
main.rs(102, 40): required by a bound introduced by this call
mod.rs(239, 8): required by a bound in `smartcore::model_selection::cross_validate`
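One possible reading of the error, sketched with hypothetical stand-in types (Estimator, ToyParams, ToyClassifier and toy_cross_validate are made up for illustration and are not smartcore's real definitions): the list of implementors above names the estimator types themselves, whereas the first argument in the call is the function item RandomForestClassifier::fit, which is a different type. If cross_validate puts the SupervisedEstimator bound on the type of its first argument, then a value of the estimator type would satisfy it and the fit function would not. The sketch assumes the trait has a constructor-like new(), which should be checked against the installed smartcore version.

```rust
// Hypothetical stand-in for a supervised-estimator trait (not smartcore's).
trait Estimator<P>: Sized {
    fn new() -> Self;
    fn fit(x: &[f64], y: &[u32], params: P) -> Result<Self, String>;
    fn predict(&self, x: &[f64]) -> Vec<u32>;
}

// Hypothetical hyperparameters for the toy model.
#[derive(Clone, Default)]
struct ToyParams {
    threshold: f64,
}

// Toy classifier: predicts 1 when the feature exceeds the threshold.
struct ToyClassifier {
    threshold: f64,
}

impl Estimator<ToyParams> for ToyClassifier {
    fn new() -> Self {
        ToyClassifier { threshold: 0.0 }
    }

    fn fit(_x: &[f64], _y: &[u32], params: ToyParams) -> Result<Self, String> {
        // No real training; the "model" just stores the threshold.
        Ok(ToyClassifier { threshold: params.threshold })
    }

    fn predict(&self, x: &[f64]) -> Vec<u32> {
        x.iter().map(|v| u32::from(*v > self.threshold)).collect()
    }
}

// The trait bound is on E, the type of the first argument. The value itself
// is unused; it only tells the compiler which estimator type to fit.
fn toy_cross_validate<E, P>(_estimator: E, x: &[f64], y: &[u32], params: P) -> Result<f64, String>
where
    E: Estimator<P>,
{
    let model = E::fit(x, y, params)?;
    let predicted = model.predict(x);
    let hits = predicted.iter().zip(y).filter(|(p, t)| p == t).count();
    Ok(hits as f64 / y.len() as f64)
}

fn demo_score() -> f64 {
    let x: [f64; 4] = [0.1, 0.9, 0.2, 0.8];
    let y: [u32; 4] = [0, 1, 0, 1];
    // Compiles: the argument has type ToyClassifier, which implements the trait.
    toy_cross_validate(ToyClassifier::new(), &x, &y, ToyParams { threshold: 0.5 }).unwrap()
    // Would not compile: ToyClassifier::fit is a function item, and function
    // items do not implement Estimator<_>:
    // toy_cross_validate(ToyClassifier::fit, &x, &y, ToyParams::default())
}

fn main() {
    println!("accuracy: {}", demo_score());
}
```

If the same reasoning carries over, passing an estimator value (something like RandomForestClassifier::new(), assuming that constructor exists in the version in use) instead of RandomForestClassifier::fit would be the thing to try first.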
I realize this is probably more specific to smartcore than Rust, but I figured I'd post here in the hopes that someone might be able to help. If this is too specific, I'll create a GitHub issue on the repo.