I have a `TreeSearcher` struct, defined as below:
/// Tree search state bundled with the callback it was constructed with.
///
/// `F` is the callback type: it takes a `Tensor<f32>` by value and returns a
/// `(Tensor<f32>, f32)` pair.
pub struct TreeSearcher<F: Fn(Tensor<f32>) -> (Tensor<f32>, f32)> {
    /// Search constant supplied by the caller
    /// (presumably the PUCT exploration weight — TODO confirm).
    c_puct: f32,
    /// Reference-counted, interior-mutable handle to the root `TreeNode`.
    root: Rc<RefCell<TreeNode>>,
    /// Callback stored at construction time.
    predict_fn: F,
}
impl<F> TreeSearcher<F>
where
    F: Fn(Tensor<f32> /*(1,4,15,15)*/) -> (Tensor<f32> /*(1,15*15)*/, f32),
{
    /// Builds a searcher from a search constant and a prediction callback.
    ///
    /// * `c_puct` - stored as-is in the searcher.
    /// * `predict_fn` - callback kept for later use; per the inline shape
    ///   annotations on the bound it takes a `(1,4,15,15)` tensor and returns
    ///   a `(1,15*15)` tensor together with an `f32`.
    ///
    /// The tree starts with a single root created by `TreeNode::new(1f32)`
    /// (the meaning of the `1f32` argument is defined by `TreeNode` —
    /// presumably a prior probability; TODO confirm).
    pub fn new(c_puct: f32, predict_fn: F) -> Self {
        let root = TreeNode::new(1f32);
        Self {
            c_puct,
            root,
            predict_fn,
        }
    }
}
Now I declare a `Trainer` struct that uses the `TreeSearcher`:
pub struct Trainer {
model: PolicyValueModel,
searcher: TreeSearcher<fn(Tensor<f32>) -> (Tensor<f32>, f32)>,
}
impl Trainer {
pub fn new() -> Self {
let model = get_best_model();
let mut searcher = TreeSearcher::new(5f32, |state_batch| {
model.predict(&state_batch).expect("Failed to predict")
});
Self {
model: model,
searcher: searcher, // <---- expected fn pointer, found closure
}
}
}
= note: expected struct `mcts::TreeSearcher<fn(tensorflow::Tensor<f32>) -> (tensorflow::Tensor<f32>, f32)>`
           found struct `mcts::TreeSearcher<[closure@src/train.rs:13:52: 15:10]>`
How can I solve this error?