Hi
I am trying to code a neural network from scratch. My data structure is a tree in which each node wraps an f32.
use std::cell::RefCell;
use std::rc::{Rc, Weak};

#[derive(Clone, Debug)]
pub struct Scalar {
    node: Rc<RefCell<Node>>,
}

#[derive(Clone, Debug)]
struct Node {
    value: f32,
    children: Vec<Rc<Scalar>>,
    parent: Weak<Scalar>,
    label: String,
}
These are created with the following code:
impl Scalar {
    pub fn new(value: f32) -> Self {
        Self {
            node: Rc::new(RefCell::new(Node::new(value))),
        }
    }

    pub fn with_label(&self, label: &str) -> Self {
        {
            self.node.borrow_mut().label = label.to_string();
        }
        self.clone()
    }
}

impl Node {
    fn new(value: f32) -> Self {
        Self {
            value,
            children: vec![],
            parent: Weak::new(),
            label: "".to_string(),
        }
    }
}
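As a quick illustration of how these constructors behave, here is a trimmed-down sketch. The children and parent fields are omitted for brevity; everything else is copied from the code above. It shows that with_label mutates the shared node in place and returns a clone that points at the same allocation, so the original handle sees the label too.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Trimmed-down copy of Node: children and parent are left out on purpose.
#[derive(Clone, Debug)]
struct Node {
    value: f32,
    label: String,
}

#[derive(Clone, Debug)]
struct Scalar {
    node: Rc<RefCell<Node>>,
}

impl Scalar {
    fn new(value: f32) -> Self {
        let node = Node { value, label: String::new() };
        Self { node: Rc::new(RefCell::new(node)) }
    }

    // Sets the label on the shared node, then returns another handle to it.
    fn with_label(&self, label: &str) -> Self {
        self.node.borrow_mut().label = label.to_string();
        self.clone()
    }
}

fn main() {
    let a = Scalar::new(3.1);
    let labelled = a.with_label("a");

    // Both handles point at the same Rc allocation.
    assert!(Rc::ptr_eq(&a.node, &labelled.node));
    assert_eq!(Rc::strong_count(&a.node), 2);

    // The label (and value) are visible through either handle.
    assert_eq!(a.node.borrow().label, "a");
    assert_eq!(labelled.node.borrow().value, 3.1);
}
```

Cloning a Scalar therefore copies the handle, not the underlying Node.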
When performing a basic maths operation such as addition, for example c = a + b, the operands a and b become children of c with strong references, while a and b each link back to c with weak references. This is so that I can later implement automatic differentiation for backpropagation. I implement Add for Scalar like this:
use std::ops::Add;

impl Add<&Scalar> for &Scalar {
    type Output = Scalar;

    fn add(self, rhs: &Scalar) -> Self::Output {
        let sum = self.node.borrow().value + rhs.node.borrow().value;
        let scalar = Scalar::new(sum);
        // make children
        scalar.node.borrow_mut().children = vec![Rc::new(self.clone()), Rc::new(rhs.clone())];
        // children need parents
        // FIXME: THIS DOES NOT WORK
        self.node.borrow_mut().parent = Rc::downgrade(&Rc::new(scalar.clone()));
        rhs.node.borrow_mut().parent = Rc::downgrade(&Rc::new(scalar.clone()));
        eprintln!(
            "trying to upgrade............. {:?}",
            self.node.borrow().parent
        );
        scalar
    }
}
The weak references back to the parent always return None when I upgrade them. According to the documentation and the Rust book's chapter on reference cycles, this happens when the parent has gone out of scope and been dropped. I cannot identify why that happens here, and I think I am running around in circles trying to find a solution. My test case is the following:
#[cfg(test)]
mod tests {
    use super::*;

    fn float_compare(a: f32, b: f32) -> bool {
        const TOLERANCE: f32 = 1e-6;
        (a - b).abs() < TOLERANCE
    }

    #[test]
    fn add_two_scalars() {
        let a = Scalar::new(3.1).with_label("a");
        let b = Scalar::new(4.2).with_label("b");
        let c = (&a + &b).with_label("c");
        assert!(float_compare(c.node.borrow().value, 7.3));
        let label_child_0 = c.node.borrow().children[0].node.borrow().label.clone();
        let label_child_1 = c.node.borrow().children[1].node.borrow().label.clone();
        assert_eq!(label_child_0, "a");
        assert_eq!(label_child_1, "b");
        assert!(float_compare(a.node.borrow().value, 3.1));
        // TODO: this assertion currently passes but should not,
        // because the parent c is still alive at this point
        assert!(a.node.borrow().parent.upgrade().is_none());
    }
}
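To narrow the upgrade problem down, here is a minimal reproduction of the pattern from my add, using a plain i32 instead of Scalar. The helper names weak_from_temporary and weak_from_owner are made up for this example. My understanding, which may be incomplete, is that Rc::downgrade(&Rc::new(...)) creates a brand-new Rc as a temporary, and that temporary is dropped at the end of the statement, so the Weak never has a living owner.

```rust
use std::rc::{Rc, Weak};

// Hypothetical helper: mirrors `Rc::downgrade(&Rc::new(...))` from `add`.
fn weak_from_temporary(value: i32) -> Weak<i32> {
    // The new Rc is a temporary. It is dropped at the end of this `let`
    // statement, which takes the strong count to zero and frees the value.
    let weak = Rc::downgrade(&Rc::new(value));
    weak
}

// Hypothetical helper: downgrades an Rc that the caller keeps alive.
fn weak_from_owner(owner: &Rc<i32>) -> Weak<i32> {
    Rc::downgrade(owner)
}

fn main() {
    // No strong reference survives, so upgrade() yields None.
    let dangling = weak_from_temporary(7);
    assert!(dangling.upgrade().is_none());
    assert_eq!(dangling.strong_count(), 0);

    // While `owner` is alive, the Weak can be upgraded.
    let owner = Rc::new(7);
    let live = weak_from_owner(&owner);
    assert_eq!(live.upgrade().map(|rc| *rc), Some(7));

    // Once the last strong reference is gone, upgrade() fails again.
    drop(owner);
    assert!(live.upgrade().is_none());
}
```

If this reading is right, wrapping scalar.clone() in a fresh Rc gives the Weak a different allocation from the one that the returned scalar keeps alive, which would explain the None.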
I would like help identifying what I am doing wrong in my code. Suggestions for crates that do similar tasks are appreciated, but I would prefer to solve this with as few dependencies as possible, because my main goal is to expose myself to the standard library and become a better Rust programmer. This is the first time I am learning about interior mutability (I come from the world of Python), so if there is a more idiomatic way to do this in Rust, please feel free to critique my code.
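For reference, here is a minimal sketch of one possible restructuring. It is a sketch under stated assumptions, not a definitive design. It assumes the parent link can be a Weak<RefCell<Node>> that points at the parent's existing shared node, that children can be stored as plain Scalar handles, and that labels can be left out. The value and parent accessors are hypothetical additions for the example.

```rust
use std::cell::RefCell;
use std::ops::Add;
use std::rc::{Rc, Weak};

#[derive(Debug)]
struct Node {
    value: f32,
    // Strong handles: the result keeps its operands alive.
    children: Vec<Scalar>,
    // Weak link to the parent's shared node (assumption: one parent only).
    parent: Weak<RefCell<Node>>,
}

#[derive(Clone, Debug)]
pub struct Scalar {
    node: Rc<RefCell<Node>>,
}

impl Scalar {
    pub fn new(value: f32) -> Self {
        let node = Node { value, children: vec![], parent: Weak::new() };
        Self { node: Rc::new(RefCell::new(node)) }
    }

    // Hypothetical accessor for the wrapped value.
    pub fn value(&self) -> f32 {
        self.node.borrow().value
    }

    // Hypothetical accessor: Some(parent) while the parent is still alive.
    pub fn parent(&self) -> Option<Scalar> {
        self.node.borrow().parent.upgrade().map(|node| Scalar { node })
    }
}

impl Add<&Scalar> for &Scalar {
    type Output = Scalar;

    fn add(self, rhs: &Scalar) -> Scalar {
        let out = Scalar::new(self.value() + rhs.value());
        out.node.borrow_mut().children = vec![self.clone(), rhs.clone()];
        // Downgrade the Rc that `out` itself owns, not a fresh temporary.
        let weak = Rc::downgrade(&out.node);
        self.node.borrow_mut().parent = weak.clone();
        rhs.node.borrow_mut().parent = weak;
        out
    }
}

fn main() {
    let a = Scalar::new(3.0);
    let b = Scalar::new(4.0);
    let c = &a + &b;
    assert_eq!(c.value(), 7.0);
    assert_eq!(c.node.borrow().children.len(), 2);

    // The weak link resolves to the same node that `c` owns.
    let parent = a.parent().expect("parent should still be alive");
    assert!(Rc::ptr_eq(&parent.node, &c.node));

    // After every strong handle to the parent is gone, the link expires.
    drop(parent);
    drop(c);
    assert!(a.parent().is_none());
}
```

One limitation of this sketch is that a node has room for a single parent, so reusing a in a second expression would overwrite its link to c.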
Thanks!