Generic function parameter and closure

I have a TreeSearcher struct defined as follows:

use std::cell::RefCell;
use std::rc::Rc;
use tensorflow::Tensor;

pub struct TreeSearcher<F>
where
    F: Fn(Tensor<f32>) -> (Tensor<f32>, f32),
{
    c_puct: f32,
    root: Rc<RefCell<TreeNode>>,
    predict_fn: F,
}

impl<F> TreeSearcher<F>
where
    F: Fn(Tensor<f32> /*(1,4,15,15)*/) -> (Tensor<f32> /*(1,15*15)*/, f32),
{
    pub fn new(c_puct: f32, predict_fn: F) -> Self {
        Self {
            c_puct,
            root: TreeNode::new(1f32),
            predict_fn,
        }
    }
}

Now I declare a Trainer struct that uses the TreeSearcher:

pub struct Trainer {
    model: PolicyValueModel,
    searcher: TreeSearcher<fn(Tensor<f32>) -> (Tensor<f32>, f32)>,
}

impl Trainer {
    pub fn new() -> Self {
        let model = get_best_model();

        let mut searcher = TreeSearcher::new(5f32, |state_batch| {
            model.predict(&state_batch).expect("Failed to predict")
        });
        Self {
            model: model,
            searcher: searcher,  // <---- expected fn pointer, found closure
        }
    }
}

= note: expected struct mcts::TreeSearcher<fn(tensorflow::Tensor<f32>) -> (tensorflow::Tensor<f32>, f32)>
found struct mcts::TreeSearcher<[closure@src/train.rs:13:52: 15:10]>

How can I fix this error?

You're trying to pass a closure, but the searcher field of Trainer declares TreeSearcher with a function pointer type, fn(Tensor<f32>) -> (Tensor<f32>, f32). Closure syntax can produce a function pointer like that only if the closure doesn't capture anything from its environment. Yours captures model, so it can't be coerced to a function pointer.
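To illustrate the coercion rule, here is a minimal sketch with std-only stand-ins: Tensor is replaced by Vec<f32>, TreeSearcher is reduced to two fields (no root), and PredictFn is a hypothetical alias for the fn-pointer type. A non-capturing closure bound to a variable annotated with the fn-pointer type is coerced, so the result really is a TreeSearcher<PredictFn>.

```rust
// Stand-in for tensorflow::Tensor<f32> so the sketch compiles with std only (assumption).
type Tensor = Vec<f32>;
// Hypothetical alias for the fn-pointer type used in the original `searcher` field.
type PredictFn = fn(Tensor) -> (Tensor, f32);

// Simplified TreeSearcher: the `root` field is left out (assumption).
pub struct TreeSearcher<F: Fn(Tensor) -> (Tensor, f32)> {
    c_puct: f32,
    predict_fn: F,
}

impl<F: Fn(Tensor) -> (Tensor, f32)> TreeSearcher<F> {
    pub fn new(c_puct: f32, predict_fn: F) -> Self {
        Self { c_puct, predict_fn }
    }
}

pub fn fn_pointer_searcher() -> TreeSearcher<PredictFn> {
    // Captures nothing, so it can be coerced to the fn-pointer type.
    let f: PredictFn = |state: Tensor| {
        let value = state.iter().sum::<f32>();
        (state, value)
    };
    // A closure that used a local such as `model` here would be rejected
    // with "expected fn pointer, found closure".
    TreeSearcher::new(5.0, f)
}
```

The coercion only kicks in because the expected type (PredictFn) is known at that point; a capturing closure never qualifies, whatever the annotation.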

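You're also trying to use model inside the closure AND store model in Trainer, which isn't going to work: without move, the closure borrows the local model, which is then moved into the struct while still borrowed; with move, the closure takes ownership of model, so there's nothing left to store. If you don't need to hold on to model, you can just get rid of the struct field.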

You can use a type parameter on Trainer to get around the closure issue, though since that type parameter has to be a closure type, which can't be named, things do get a bit messy.


pub struct Trainer<F>
where
    F: Fn(Tensor<f32>) -> (Tensor<f32>, f32),
{
    searcher: TreeSearcher<F>,
}

// Implemented on the concrete `fn(Tensor<f32>) -> (Tensor<f32>, f32)` type so type inference doesn't fail when you call `Trainer::new()`.
// As far as I know, stable Rust has no way to write `impl Trait` in an impl header.
impl Trainer<fn(Tensor<f32>) -> (Tensor<f32>, f32)> {
    pub fn new() -> Trainer<impl Fn(Tensor<f32>) -> (Tensor<f32>, f32)> {
        let model = get_best_model();

        let searcher = TreeSearcher::new(5f32, move |state_batch| {
            model.predict(&state_batch).expect("Failed to predict")
        });
        Trainer {
            searcher,
        }
    }
}
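For a self-contained check of the approach above, here is a sketch that compiles with std only. The stand-ins are assumptions, not the real crates: Tensor is Vec<f32>, PolicyValueModel is a dummy whose hypothetical predict returns the input and its sum, and TreeSearcher drops the root field.

```rust
// Stand-in for tensorflow::Tensor<f32> (assumption).
type Tensor = Vec<f32>;

// Dummy model; `predict` is hypothetical and only mimics the original call shape.
pub struct PolicyValueModel;

impl PolicyValueModel {
    pub fn predict(&self, state: &Tensor) -> Result<(Tensor, f32), String> {
        Ok((state.clone(), state.iter().sum::<f32>()))
    }
}

fn get_best_model() -> PolicyValueModel {
    PolicyValueModel
}

// TreeSearcher reduced to the fields needed here (`root` omitted).
pub struct TreeSearcher<F: Fn(Tensor) -> (Tensor, f32)> {
    c_puct: f32,
    predict_fn: F,
}

impl<F: Fn(Tensor) -> (Tensor, f32)> TreeSearcher<F> {
    pub fn new(c_puct: f32, predict_fn: F) -> Self {
        Self { c_puct, predict_fn }
    }
}

pub struct Trainer<F: Fn(Tensor) -> (Tensor, f32)> {
    searcher: TreeSearcher<F>,
}

impl Trainer<fn(Tensor) -> (Tensor, f32)> {
    pub fn new() -> Trainer<impl Fn(Tensor) -> (Tensor, f32)> {
        let model = get_best_model();
        // `move` transfers ownership of `model` into the closure.
        let searcher = TreeSearcher::new(5.0, move |state_batch: Tensor| {
            model.predict(&state_batch).expect("Failed to predict")
        });
        Trainer { searcher }
    }
}
```

Calling Trainer::new() with no annotations works here because only one impl provides new, so the compiler picks it and fixes the impl's type parameter to the fn-pointer type.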

In return position, impl Fn(Tensor<f32>) -> (Tensor<f32>, f32) means "there's a single concrete type that implements Fn(Tensor<f32>) -> (Tensor<f32>, f32), and the compiler infers it from the body of this function". That lets us return a closure without writing out its type, which is good because closure types are anonymous and can't be named.
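A minimal sketch of return-position impl Trait on its own, assuming Vec<f32> in place of Tensor; make_predictor, Holder and make_holder are hypothetical names. It also shows the opaque type nested inside another generic type, which is what Trainer<impl Fn(...)> does.

```rust
// Stand-in for tensorflow::Tensor<f32> (assumption, std only).
type Tensor = Vec<f32>;

// Returns "some type implementing Fn(Tensor) -> (Tensor, f32)". The concrete
// type is the anonymous closure below; callers never see its name.
pub fn make_predictor(scale: f32) -> impl Fn(Tensor) -> (Tensor, f32) {
    move |state: Tensor| {
        let value = scale * state.iter().sum::<f32>();
        (state, value)
    }
}

// Hypothetical generic wrapper, playing the role of Trainer<F>.
pub struct Holder<F: Fn(Tensor) -> (Tensor, f32)> {
    predict_fn: F,
}

// The opaque type can appear as a type argument of another type.
pub fn make_holder(scale: f32) -> Holder<impl Fn(Tensor) -> (Tensor, f32)> {
    Holder { predict_fn: make_predictor(scale) }
}
```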

The method is implemented on Trainer<fn(Tensor<f32>) -> (Tensor<f32>, f32)> rather than generically on Trainer<F> to avoid errors when calling new. If we instead wrote:

impl<F> Trainer<F> where F: Fn(Tensor<f32>) -> (Tensor<f32>, f32) {
// same as before
}

We'd get a "type annotations needed" error when calling new, because nothing in the call lets the compiler infer F: F isn't an argument, and the returned Trainer's type parameter is the closure type chosen inside new, not F. We'd have to spell F out with a turbofish, e.g. Trainer::<fn(Tensor<f32>) -> (Tensor<f32>, f32)>::new(), just to satisfy the compiler.
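Here is a sketch of that generic variant, assuming Vec<f32> for Tensor and a captured scale factor standing in for model (both assumptions; PredictFn and build are hypothetical names). The turbofish picks an arbitrary concrete F; the returned Trainer still holds the closure, not a PredictFn.

```rust
// Stand-in for tensorflow::Tensor<f32> (assumption, std only).
type Tensor = Vec<f32>;
// Hypothetical alias for the fn-pointer type used in the turbofish.
type PredictFn = fn(Tensor) -> (Tensor, f32);

pub struct Trainer<F: Fn(Tensor) -> (Tensor, f32)> {
    predict_fn: F,
}

// Generic version: `F` is never used by `new`, so a bare `Trainer::new()`
// leaves `F` uninferred and fails with "type annotations needed".
impl<F: Fn(Tensor) -> (Tensor, f32)> Trainer<F> {
    pub fn new() -> Trainer<impl Fn(Tensor) -> (Tensor, f32)> {
        // `scale` stands in for the captured model (assumption).
        let scale = 2.0_f32;
        Trainer {
            predict_fn: move |state: Tensor| {
                let value = scale * state.iter().sum::<f32>();
                (state, value)
            },
        }
    }
}

// The turbofish chooses F only to satisfy inference; it has no other effect.
pub fn build() -> Trainer<impl Fn(Tensor) -> (Tensor, f32)> {
    Trainer::<PredictFn>::new()
}
```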

If you actually need access to model in both places, you'll have to share it with another Rc (plus a RefCell if you need mutable access), like the Rc<RefCell<TreeNode>> you already use for root.
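A minimal sketch of that sharing pattern with std-only stand-ins (assumptions): Tensor is Vec<f32>, PolicyValueModel is a dummy whose hypothetical predict takes &mut self and counts calls, and the closure is stored directly in Trainer instead of going through TreeSearcher to keep it short.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Stand-in for tensorflow::Tensor<f32> (assumption, std only).
type Tensor = Vec<f32>;

// Hypothetical model: `predict` takes &mut self here only to show why
// RefCell is needed when both owners may mutate the model.
pub struct PolicyValueModel {
    calls: u32,
}

impl PolicyValueModel {
    pub fn predict(&mut self, state: &Tensor) -> (Tensor, f32) {
        self.calls += 1;
        (state.clone(), state.iter().sum::<f32>())
    }
}

pub struct Trainer<F: Fn(Tensor) -> (Tensor, f32)> {
    model: Rc<RefCell<PolicyValueModel>>,
    predict_fn: F,
}

impl Trainer<fn(Tensor) -> (Tensor, f32)> {
    pub fn new() -> Trainer<impl Fn(Tensor) -> (Tensor, f32)> {
        let model = Rc::new(RefCell::new(PolicyValueModel { calls: 0 }));
        // The closure owns its own Rc handle; Trainer keeps the other one.
        let closure_model = Rc::clone(&model);
        let predict_fn = move |state: Tensor| closure_model.borrow_mut().predict(&state);
        Trainer { model, predict_fn }
    }
}
```

borrow_mut panics if the model is already borrowed at that moment, so avoid holding a borrow from Trainer's handle while calling the closure.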
