Hi there,
I'm new to Rust from a Python background, and wanted to implement this tiny autograd library in Rust as a learning experience. It is not going well.
The first step is to implement the Value class as a struct, but I'm struggling with the backward field, whose type is some kind of function.
Here's what I have so far:
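use std::ops::Add;

pub struct Value {
    pub data: f32,
    pub grad: f32,
    // Plain function pointer: no arguments, no captured state.
    pub backward: fn() -> (),
}

fn backward() {
    println!("Backward implemented!");
}

impl Value {
    pub fn new(data: f32) -> Value {
        Value {
            data: data,
            grad: 0.0,
            // A non-capturing closure coerces to fn().
            backward: || -> () {},
        }
    }
}

impl Add for Value {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        // No trailing semicolon, so the new Value is returned.
        Self::new(self.data + other.data)
    }
}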
This compiles fine, but the backward field is wrong when a Value is constructed as the sum of two other values. So in the case
let v1: Value = Value::new(1.0);
let v2: Value = Value::new(2.3);
let v3: Value = v1 + v2;
I want v3.backward to have functionality like
// Pseudocode, not valid Rust
fn v3.backward()->(){
    v1.grad += v3.grad;
    v2.grad += v3.grad;
    v1.backward();
    v2.backward();
}
How can I go about this? I tried writing an add_backward function and assigning it to v3.backward in the implementation of Add for Value, like so:
fn add_backward(grad: f32, left: &Value, right: &Value) {
    left.grad += grad;
    right.grad += grad;
    (left.backward)(grad, left, right);
    (right.backward)(grad, left, right);
}

impl Add for Value {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        let mut out = Self::new(self.data + other.data);
        out.backward = add_backward;
        out
    }
}
But the type signatures don't match: add_backward takes three arguments, while the backward field is declared as fn() -> (). How can I set a function in a struct field that takes completely arbitrary arguments? Is this even the right way of going about it? I thought about using a closure in the impl Add bit, but I don't know how to type that in the struct definition.
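From what I can gather, one candidate type for such a field is a boxed trait object, Box<dyn Fn()>, which can hold a closure that captures state. Here is a minimal sketch under that assumption; the calls counter is purely illustrative and only shows that the stored closure really captures something. I'm not sure this is the idiomatic choice.

```rust
use std::cell::Cell;
use std::rc::Rc;

// Variant of Value where backward is a boxed closure instead of fn().
pub struct Value {
    pub data: f32,
    pub grad: f32,
    pub backward: Box<dyn Fn()>,
}

impl Value {
    pub fn new(data: f32) -> Value {
        // Default backward does nothing.
        Value { data, grad: 0.0, backward: Box::new(|| {}) }
    }
}

fn main() {
    // Illustrative shared counter, captured by the closure below.
    let calls = Rc::new(Cell::new(0u32));
    let seen = Rc::clone(&calls);

    let mut v = Value::new(1.0);
    // A capturing closure would not fit in fn(), but it fits in Box<dyn Fn()>.
    v.backward = Box::new(move || seen.set(seen.get() + 1));

    (v.backward)();
    (v.backward)();
    println!("data = {}, grad = {}, backward ran {} times", v.data, v.grad, calls.get());
}
```

The parentheses in (v.backward)() are needed because backward is a field, not a method.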
Note that I think backward can't be implemented as a method or a trait, as it depends on how the specific instance of Value is created.
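To make the per-instance behaviour concrete, here is a rough sketch of what I'm after. It assumes shared ownership through Rc<RefCell<...>> so that v3 can still reach v1 and v2 after the addition. Node, accumulate, and the f32 argument passed to the closure are names and choices of my own, not something taken from the original library, and I don't know whether this is how it should be done.

```rust
use std::cell::RefCell;
use std::ops::Add;
use std::rc::Rc;

// Hypothetical inner node; leaves have no backward closure.
pub struct Node {
    pub data: f32,
    pub grad: f32,
    pub backward: Option<Box<dyn Fn(f32)>>,
}

// Value is a cheap, clonable handle to a shared Node.
#[derive(Clone)]
pub struct Value(Rc<RefCell<Node>>);

impl Value {
    pub fn new(data: f32) -> Value {
        Value(Rc::new(RefCell::new(Node { data, grad: 0.0, backward: None })))
    }
    pub fn data(&self) -> f32 { self.0.borrow().data }
    pub fn grad(&self) -> f32 { self.0.borrow().grad }

    // Add the upstream gradient to this node, then pass it on to its parents.
    fn accumulate(&self, upstream: f32) {
        self.0.borrow_mut().grad += upstream; // mutable borrow ends here
        let node = self.0.borrow();
        if let Some(f) = &node.backward {
            f(upstream);
        }
    }

    // Start backpropagation from this value with a gradient of 1.
    pub fn backward(&self) {
        self.accumulate(1.0);
    }
}

impl Add for Value {
    type Output = Value;
    fn add(self, other: Value) -> Value {
        let out = Value::new(self.data() + other.data());
        let (left, right) = (self, other);
        // The closure is specific to this instance: it captures its two parents.
        let f: Box<dyn Fn(f32)> = Box::new(move |g| {
            left.accumulate(g);
            right.accumulate(g);
        });
        out.0.borrow_mut().backward = Some(f);
        out
    }
}

fn main() {
    let v1 = Value::new(1.0);
    let v2 = Value::new(2.3);
    // Clone the handles so v1 and v2 stay usable after the addition.
    let v3 = v1.clone() + v2.clone();
    v3.backward();
    println!("v3 = {}, v1.grad = {}, v2.grad = {}", v3.data(), v1.grad(), v2.grad());
}
```

In this sketch the closure stored in v3 is created inside add, so it depends on how that particular Value was built, while a Value made with new simply has no closure.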
Any help will be greatly appreciated!