What type to use for function in a struct?

Hi there,

I'm new to Rust from a Python background, and wanted to implement this tiny autograd library in Rust as a learning experience. It is not going well.

First step is to implement the Value class as a struct, but I'm struggling with the backward field, whose type is some kind of function.

Here's what I have so far:

    use std::ops::Add;

    pub struct Value {
        pub data: f32,
        pub grad: f32,
        pub backward: fn() -> (),
    }

    fn backward() {
        println!("Backward implemented!");
    }

    impl Value {
        pub fn new(data: f32) -> Value {
            Value {
                data,
                grad: 0.0,
                // A non-capturing closure coerces to a plain `fn` pointer.
                backward: || -> () {},
            }
        }
    }

    impl Add for Value {
        type Output = Self;

        fn add(self, other: Self) -> Self {
            // No trailing semicolon: the new Value is the return value.
            Self::new(self.data + other.data)
        }
    }

This compiles fine, but the backward field is wrong when a Value is constructed as the sum of two other values. So in the case

let v1: Value = Value::new(1.0);
let v2: Value = Value::new(2.3);
let v3: Value = v1 + v2;

I want v3.backward to have functionality like

fn v3.backward()->(){
    v1.grad += v3.grad;
    v2.grad += v3.grad;
    v1.backward();
    v2.backward();
}

How can I go about this? I tried writing an add_backward function and assigning it to v3.backward in the implementation of Add for Value like so:

fn add_backward(grad:f32, left: &Value, right: &Value){
    left.grad += grad;
    right.grad += grad;
    (left.backward)(grad,left,right);
    (right.backward)(grad,left,right);
}

    impl Add for Value {
        type Output = Self;

        fn add(self, other: Self) -> Self {
            let mut out = Self::new(self.data + other.data);
            out.backward = add_backward;
            out
        }
    }

But the type signatures don't match. How can I set a function in a struct field that takes completely arbitrary arguments? Is this even the right way of going about it? I thought about using a closure in the impl Add bit, but I don't know how to write its type in the struct definition.

Note that I think backward can't be implemented as a method or a trait, as it depends on how the specific instance of Value is created.

Any help will be greatly appreciated!

One option would be to pass the struct itself as an argument to the backward function, and store the other necessary data within the struct:

pub struct Value{
    pub data: f32,
    pub grad: f32,
    pub backward: fn(&mut Value),
    pub left: Option<Box<Value>>,
    pub right: Option<Box<Value>>,
}

impl Value {
    pub fn new(data:f32) -> Value{
        Value {
            data: data,
            grad: 0.0,
            left: None,
            right: None,
            backward: |_| {},
        }
    }
}

Then the add function looks like this:

fn add(self, other: Self) -> Self {
    Value {
        data: self.data + other.data,
        grad: 0.0,
        left: Some(Box::new(self)),
        right: Some(Box::new(other)),
        backward: |this| {
            if let Some(left) = &mut this.left {
                left.grad += this.grad;
                (left.backward)(left);
            }
            if let Some(right) = &mut this.right {
                right.grad += this.grad;
                (right.backward)(right);
            }
        }
    }
}
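To see the pieces working together, here is a minimal sketch that puts the struct, the constructor, and the add function above into one file and runs a backward pass. The `run_example` helper is a hypothetical name added only for illustration, and seeding `v3.grad` with 1.0 is an assumption about how the pass would be started.

```rust
use std::ops::Add;

pub struct Value {
    pub data: f32,
    pub grad: f32,
    pub backward: fn(&mut Value),
    pub left: Option<Box<Value>>,
    pub right: Option<Box<Value>>,
}

impl Value {
    pub fn new(data: f32) -> Value {
        // Leaf values have no inputs, so their backward step does nothing.
        Value { data, grad: 0.0, left: None, right: None, backward: |_| {} }
    }
}

impl Add for Value {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Value {
            data: self.data + other.data,
            grad: 0.0,
            left: Some(Box::new(self)),
            right: Some(Box::new(other)),
            // Addition passes the output gradient unchanged to both inputs.
            backward: |this| {
                if let Some(left) = &mut this.left {
                    left.grad += this.grad;
                    (left.backward)(left);
                }
                if let Some(right) = &mut this.right {
                    right.grad += this.grad;
                    (right.backward)(right);
                }
            },
        }
    }
}

/// Hypothetical helper: returns (v3.data, gradient of left input, gradient of right input).
pub fn run_example() -> (f32, f32, f32) {
    let v1 = Value::new(1.0);
    let v2 = Value::new(2.3);
    let mut v3 = v1 + v2; // v1 and v2 are moved into v3 here
    v3.grad = 1.0; // seed the output gradient (an assumption for this sketch)
    let backward = v3.backward; // fn pointers are Copy, so this does not borrow v3
    backward(&mut v3);
    let left_grad = v3.left.as_ref().map_or(0.0, |v| v.grad);
    let right_grad = v3.right.as_ref().map_or(0.0, |v| v.grad);
    (v3.data, left_grad, right_grad)
}
```

Calling `run_example()` from `main` should show both input gradients equal to the seeded output gradient. Note that the inputs can only be inspected through `v3.left` and `v3.right` afterwards, because `v1` and `v2` were moved.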

Another option might be to use boxed closures, but it might get tricky letting a closure mutate the struct that it's stored within.
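For what it's worth, a boxed closure field could be typed roughly as below. This is only a sketch: the `run_backward` method, the `scaled_example` helper, and the captured `scale` variable are hypothetical names chosen for illustration. The tricky part mentioned above shows up in `run_backward`, where the closure is temporarily moved out of the struct so that the struct is not borrowed twice while the closure mutates it.

```rust
pub struct Value {
    pub data: f32,
    pub grad: f32,
    // Unlike a plain `fn` pointer, a boxed closure may capture variables.
    pub backward: Option<Box<dyn Fn(&mut Value)>>,
}

impl Value {
    pub fn new(data: f32) -> Value {
        Value { data, grad: 0.0, backward: None }
    }

    /// Hypothetical helper: runs the stored closure, if there is one.
    pub fn run_backward(&mut self) {
        // Take the closure out of `self` first; calling it in place would
        // need a shared borrow of the field and a mutable borrow of `self`
        // at the same time.
        if let Some(f) = self.backward.take() {
            f(self);
            self.backward = Some(f); // put the closure back afterwards
        }
    }
}

/// Hypothetical demo: the closure captures `scale` from its environment.
pub fn scaled_example() -> f32 {
    let scale = 0.5_f32;
    let mut v = Value::new(2.0);
    v.backward = Some(Box::new(move |this: &mut Value| {
        this.grad += scale * this.data;
    }));
    v.run_backward();
    v.grad
}
```

The `Option` wrapper exists only to make the take-and-restore step possible; a design that keeps the closure outside the struct it mutates would avoid it.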

Thanks @mbrubeck - that's solved all my compilation problems.

Will this cause each Value created as the sum of two other Values to have copies of them and their data? Not 100% clear on how Box works.

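Yes, the Value returned by add owns the two Values that were added. They are moved into heap-allocated boxes rather than copied, so v1 and v2 can no longer be used on their own afterwards.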

Ah - that might be a problem in the future. There'll be Values that are linear combinations of many other Values, which are themselves sums of other Values (and so on). Can I make sure these are references instead? How would dereferencing to update their data and grad work in the closure?
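One common way to get shared, mutable nodes in Rust is `Rc<RefCell<...>>`, where cloning a handle copies a pointer rather than the data. The following is a rough sketch under that assumption, not a definitive design: `Inner`, `parents`, `backprop`, and `shared_example` are hypothetical names, and only addition is handled. Updating `grad` goes through `borrow_mut()`, which plays the role of dereferencing a mutable reference.

```rust
use std::cell::RefCell;
use std::ops::Add;
use std::rc::Rc;

struct Inner {
    data: f32,
    grad: f32,
    parents: Vec<Value>, // the inputs this value was computed from
}

/// A cheap handle: cloning it shares the same underlying node.
#[derive(Clone)]
pub struct Value(Rc<RefCell<Inner>>);

impl Value {
    pub fn new(data: f32) -> Value {
        Value(Rc::new(RefCell::new(Inner { data, grad: 0.0, parents: Vec::new() })))
    }

    pub fn data(&self) -> f32 {
        self.0.borrow().data
    }

    pub fn grad(&self) -> f32 {
        self.0.borrow().grad
    }

    /// Adds `upstream` to this node's gradient and forwards it to the inputs.
    /// Addition has a local derivative of 1, so the value is passed on unchanged.
    /// Shared nodes are visited once per path; a topological ordering would
    /// avoid the repeated visits in larger graphs.
    pub fn backprop(&self, upstream: f32) {
        let parents = {
            let mut inner = self.0.borrow_mut();
            inner.grad += upstream;
            inner.parents.clone() // clones handles only, not node data
        }; // the mutable borrow ends here, before recursing
        for p in &parents {
            p.backprop(upstream);
        }
    }
}

impl Add for &Value {
    type Output = Value;

    fn add(self, other: Self) -> Value {
        let out = Value::new(self.data() + other.data());
        out.0.borrow_mut().parents = vec![self.clone(), other.clone()];
        out
    }
}

/// Hypothetical demo: `x` feeds into the result along two paths.
pub fn shared_example() -> (f32, f32, f32) {
    let x = Value::new(1.0);
    let y = Value::new(2.0);
    let a = &x + &y;
    let z = &a + &x;
    z.backprop(1.0);
    (z.data(), x.grad(), y.grad())
}
```

Because `Add` is implemented for `&Value`, the operands stay usable after the addition, and `x.grad()` can be read directly once the backward pass has run.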